prepare for lora branch (#47)
- remove async llm generation -- this is just doubling our work and does not match the style used in the example applications
- package generation parameters into a struct
- refactor command line arguments into distinct pieces based on their use -- this will be reusable in the lora commands
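For reference, a minimal sketch of the parameters struct this commit introduces, inferred from the `GenerateParameters(temperature: 0.6)` call site in the diff below. Only `temperature` is confirmed by this commit; the initializer shape and anything else here is an assumption for illustration:

    // Sketch only: inferred from the call site in the diff below, not the
    // repository's confirmed declaration.
    public struct GenerateParameters {
        /// sampling temperature -- higher values give more varied output
        public var temperature: Float

        public init(temperature: Float = 0.6) {
            self.temperature = temperature
        }
    }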
@@ -160,7 +160,7 @@ class LLMEvaluator {
     let modelConfiguration = ModelConfiguration.phi4bit

     /// parameters controlling the output
-    let temperature: Float = 0.6
+    let generateParameters = GenerateParameters(temperature: 0.6)
     let maxTokens = 240

     /// update the display every N tokens -- 4 looks like it updates continuously
@@ -201,7 +201,6 @@ class LLMEvaluator {
     }

     func generate(prompt: String) async {
-        let startTime = Date()
         do {
             let (model, tokenizer) = try await load()

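The `startTime` bookkeeping removed above (and the timing code removed in the next hunk) moves behind the library's `generate` call, which reports tokens/second on its result. A sketch of the measurement pattern the old code used -- reset the clock once the first token arrives, so prompt evaluation ("prompt time") is excluded -- with hypothetical names throughout:

    import Foundation

    // Illustrative only: tokens/second measured the way the removed loop
    // did. `nextToken` stands in for the real decoder; it is not an API
    // from this repository.
    func measureTokensPerSecond(nextToken: () -> Int?) -> Double {
        var start = Date()
        var count = 0
        while nextToken() != nil {
            // reset after the first token so prompt processing is excluded
            if count == 0 { start = Date() }
            count += 1
        }
        return Double(count) / Date().timeIntervalSince(start)
    }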
@@ -212,59 +211,37 @@ class LLMEvaluator {
             // augment the prompt as needed
             let prompt = modelConfiguration.prepare(prompt: prompt)
-            let promptTokens = MLXArray(tokenizer.encode(text: prompt))
-
-            var initTime = Date()
-            let initDuration = initTime.timeIntervalSince(startTime)
-            await MainActor.run {
-                self.stat = "Init: \(String(format: "%.3f", initDuration))s"
-            }
+            let promptTokens = tokenizer.encode(text: prompt)

             // each time you generate you will get something new
             MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

-            var outputTokens = [Int]()
-
-            for token in TokenIterator(prompt: promptTokens, model: model, temp: temperature) {
-                let tokenId = token.item(Int.self)
-
-                // to match the measurement from the command line we reset the start time
-                // after the first token is generated (called the prompt time)
-                if outputTokens.isEmpty {
-                    initTime = Date()
-                }
-
-                if tokenId == tokenizer.unknownTokenId || tokenId == tokenizer.eosTokenId {
-                    break
-                }
-
-                outputTokens.append(tokenId)
-                let text = tokenizer.decode(tokens: outputTokens)
-
+            let result = await LLM.generate(
+                promptTokens: promptTokens, parameters: generateParameters, model: model,
+                tokenizer: tokenizer
+            ) { tokens in
                 // update the output -- this will make the view show the text as it generates
-                if outputTokens.count % displayEveryNTokens == 0 {
+                if tokens.count % displayEveryNTokens == 0 {
+                    let text = tokenizer.decode(tokens: tokens)
                     await MainActor.run {
                         self.output = text
                     }
                 }

-                if outputTokens.count == maxTokens {
-                    break
+                if tokens.count >= maxTokens {
+                    return .stop
+                } else {
+                    return .more
                 }
             }

-            let tokenDuration = Date().timeIntervalSince(initTime)
-            let tokensPerSecond = Double(outputTokens.count) / tokenDuration
-
             // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
-            let finalText = tokenizer.decode(tokens: outputTokens)
-
             await MainActor.run {
-                if finalText != self.output {
-                    self.output = finalText
+                if result.output != self.output {
+                    self.output = result.output
                 }
                 running = false
-                self.stat += " Tokens/second: \(String(format: "%.3f", tokensPerSecond))"
+                self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
             }

         } catch {
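Taken together, the new call site implies roughly the following library-side shape: a disposition the callback returns, a result carrying the decoded output and throughput, and a callback-driven `generate`. This is a sketch reconstructed from the diff, not the library's confirmed API; `LLM` may be a module namespace rather than a type, and the stand-in protocols are assumptions:

    // Sketch inferred from the call site above -- not confirmed declarations.
    public protocol LLMModel {}            // stand-in for the repo's model type
    public protocol Tokenizer {            // minimal stand-in tokenizer
        func encode(text: String) -> [Int]
        func decode(tokens: [Int]) -> String
    }

    public enum GenerateDisposition {
        case more   // keep generating
        case stop   // stop (e.g. token budget reached)
    }

    public struct GenerateResult {
        public let output: String           // decoded text
        public let tokensPerSecond: Double  // generation throughput
    }

    public enum LLM {
        // The callback sees the tokens generated so far and decides whether
        // to continue; it is async so callers can hop to the MainActor.
        public static func generate(
            promptTokens: [Int], parameters: GenerateParameters, model: LLMModel,
            tokenizer: Tokenizer, didGenerate: ([Int]) async -> GenerateDisposition
        ) async -> GenerateResult {
            fatalError("sketch only -- see the call site in the diff above")
        }
    }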