LLMEval performance (#40)

* notes about performance and some performance improvements (don't update the display for every token)

* swift-format

* Update Applications/LLMEval/README.md

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update Applications/LLMEval/README.md

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
David Koski
2024-03-28 12:00:52 -07:00
committed by GitHub
parent 15b38cd146
commit 0199407d93
2 changed files with 28 additions and 3 deletions

View File

@@ -152,6 +152,11 @@ class LLMEvaluator {
let temperature: Float = 0.6
let maxTokens = 240
/// update the display every N tokens -- 4 looks like it updates continuously
/// and is low overhead. observed ~15% reduction in tokens/s when updating
/// on every token
let displayEveryNTokens = 4
enum LoadState {
case idle
case loaded(LLMModel, Tokenizers.Tokenizer)
@@ -198,7 +203,7 @@ class LLMEvaluator {
let prompt = modelConfiguration.prepare(prompt: prompt)
let promptTokens = MLXArray(tokenizer.encode(text: prompt))
let initTime = Date()
var initTime = Date()
let initDuration = initTime.timeIntervalSince(startTime)
await MainActor.run {
self.stat = "Init: \(String(format: "%.3f", initDuration))s"
@@ -212,6 +217,12 @@ class LLMEvaluator {
for token in TokenIterator(prompt: promptTokens, model: model, temp: temperature) {
let tokenId = token.item(Int.self)
// to match the measurement from the command line we reset the start time
// after the first token is generated (called the prompt time)
if outputTokens.isEmpty {
initTime = Date()
}
if tokenId == tokenizer.unknownTokenId || tokenId == tokenizer.eosTokenId {
break
}
@@ -220,8 +231,10 @@ class LLMEvaluator {
let text = tokenizer.decode(tokens: outputTokens)
// update the output -- this will make the view show the text as it generates
await MainActor.run {
self.output = text
if outputTokens.count % displayEveryNTokens == 0 {
await MainActor.run {
self.output = text
}
}
if outputTokens.count == maxTokens {
@@ -232,7 +245,13 @@ class LLMEvaluator {
let tokenDuration = Date().timeIntervalSince(initTime)
let tokensPerSecond = Double(outputTokens.count) / tokenDuration
// update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
let finalText = tokenizer.decode(tokens: outputTokens)
await MainActor.run {
if finalText != self.output {
self.output = finalText
}
running = false
self.stat += " Tokens/second: \(String(format: "%.3f", tokensPerSecond))"
}