prepare for lora branch (#47)

- remove async LLM generation -- it duplicates the synchronous path
	- and does not match the style used in the example applications
- package generation parameters into a struct (sketched below)
- refactor command line arguments into distinct pieces based on their use
	- these pieces will be reusable in the lora commands
Author: David Koski (committed by GitHub)
Date: 2024-04-10 10:56:18 -07:00
Parent: cedf73421f
Commit: 96b94b0df6
5 changed files with 227 additions and 274 deletions
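
For orientation, here is a minimal sketch of the types this change implies,
assembled only from the hunks below. The diff shows a struct initializable as
GenerateParameters(temperature:) and a callback that returns .stop or .more;
the field defaults and the enum name GenerateDisposition are assumptions, not
code from this commit.

    /// Packaged sampling parameters -- only `temperature` is visible in this diff.
    public struct GenerateParameters {
        /// sampling temperature used when drawing each token
        public var temperature: Float

        public init(temperature: Float) {
            self.temperature = temperature
        }
    }

    /// Hypothetical name for the callback result type; the diff shows only
    /// the cases (`return .stop` / `return .more`).
    public enum GenerateDisposition {
        /// keep generating tokens
        case more
        /// end generation, e.g. the token budget was reached
        case stop
    }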

@@ -160,7 +160,7 @@ class LLMEvaluator {
     let modelConfiguration = ModelConfiguration.phi4bit
 
     /// parameters controlling the output
-    let temperature: Float = 0.6
+    let generateParameters = GenerateParameters(temperature: 0.6)
     let maxTokens = 240
 
     /// update the display every N tokens -- 4 looks like it updates continuously
@@ -201,7 +201,6 @@ class LLMEvaluator {
     }
 
     func generate(prompt: String) async {
-        let startTime = Date()
 
         do {
             let (model, tokenizer) = try await load()
@@ -212,59 +211,37 @@ class LLMEvaluator {
             // augment the prompt as needed
             let prompt = modelConfiguration.prepare(prompt: prompt)
-            let promptTokens = MLXArray(tokenizer.encode(text: prompt))
-
-            var initTime = Date()
-            let initDuration = initTime.timeIntervalSince(startTime)
-            await MainActor.run {
-                self.stat = "Init: \(String(format: "%.3f", initDuration))s"
-            }
+            let promptTokens = tokenizer.encode(text: prompt)
 
             // each time you generate you will get something new
             MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
-            var outputTokens = [Int]()
-            for token in TokenIterator(prompt: promptTokens, model: model, temp: temperature) {
-                let tokenId = token.item(Int.self)
-
-                // to match the measurement from the command line we reset the start time
-                // after the first token is generated (called the prompt time)
-                if outputTokens.isEmpty {
-                    initTime = Date()
-                }
-
-                if tokenId == tokenizer.unknownTokenId || tokenId == tokenizer.eosTokenId {
-                    break
-                }
-
-                outputTokens.append(tokenId)
-                let text = tokenizer.decode(tokens: outputTokens)
-
+            let result = await LLM.generate(
+                promptTokens: promptTokens, parameters: generateParameters, model: model,
+                tokenizer: tokenizer
+            ) { tokens in
                 // update the output -- this will make the view show the text as it generates
-                if outputTokens.count % displayEveryNTokens == 0 {
+                if tokens.count % displayEveryNTokens == 0 {
+                    let text = tokenizer.decode(tokens: tokens)
                     await MainActor.run {
                         self.output = text
                     }
                 }
 
-                if outputTokens.count == maxTokens {
-                    break
+                if tokens.count >= maxTokens {
+                    return .stop
+                } else {
+                    return .more
                 }
             }
 
-            let tokenDuration = Date().timeIntervalSince(initTime)
-            let tokensPerSecond = Double(outputTokens.count) / tokenDuration
-
             // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
-            let finalText = tokenizer.decode(tokens: outputTokens)
-
             await MainActor.run {
-                if finalText != self.output {
-                    self.output = finalText
+                if result.output != self.output {
+                    self.output = result.output
                 }
                 running = false
-                self.stat += " Tokens/second: \(String(format: "%.3f", tokensPerSecond))"
+                self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
             }
         } catch {
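
Put together, the new call site follows this pattern -- a consolidated sketch
of the hunks above rather than code from the commit; `maxTokens` and
`displayEveryNTokens` are the properties defined earlier in the class:

    let result = await LLM.generate(
        promptTokens: tokenizer.encode(text: prompt),
        parameters: generateParameters, model: model, tokenizer: tokenizer
    ) { tokens in
        // stream partial output to the UI every few tokens
        if tokens.count % displayEveryNTokens == 0 {
            let text = tokenizer.decode(tokens: tokens)
            await MainActor.run { self.output = text }
        }
        // stop once the token budget is reached
        return tokens.count >= maxTokens ? .stop : .more
    }
    // the result carries the final text and the measured throughput
    print(result.output, result.tokensPerSecond)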