handle partially quantized models (#76)
* handle partially quantized models; fix for #53 #71 #69 #74
* in order to test the models, added a default prompt of an appropriate form
* while working on the model configuration, also added additional stop tokens (#74)
* fixed the repetitionPenalty code (#71)
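The "partially quantized" handling itself lives outside the hunks below: when a checkpoint quantizes only some layers, only the layers that actually ship quantization tensors should be swapped for quantized ones. A minimal standalone sketch of that per-layer decision, assuming MLX's convention of saving a ".scales" tensor alongside each quantized weight (the Checkpoint type and helper here are illustrative, not the repository's API):

    // Sketch only: a layer counts as quantized when the checkpoint carries
    // scales for it; every other layer stays a regular float layer.
    struct Checkpoint {
        // flattened tensor names, e.g. "layers.0.attention.wq.scales"
        var tensorNames: Set<String>
    }

    func shouldQuantize(layerPath: String, in checkpoint: Checkpoint) -> Bool {
        checkpoint.tensorNames.contains("\(layerPath).scales")
    }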
@@ -44,7 +44,7 @@ struct GenerateArguments: ParsableArguments {
         help:
             "The message to be processed by the model. Use @path,@path to load from files, e.g. @/tmp/prompt.txt"
     )
-    var prompt = "compare python and swift"
+    var prompt: String?
 
     @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
     var maxTokens = 100
@@ -73,7 +73,8 @@ struct GenerateArguments: ParsableArguments {
             repetitionContextSize: repetitionContextSize)
     }
 
-    func resolvePrompt() throws -> String {
+    func resolvePrompt(configuration: ModelConfiguration) throws -> String {
+        let prompt = self.prompt ?? configuration.defaultPrompt
         if prompt.hasPrefix("@") {
             let names = prompt.split(separator: ",").map { String($0.dropFirst()) }
             return try names.map { try String(contentsOfFile: $0) }.joined(separator: "\n")
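For reference, the "@path,@path" convention above resolves each comma-separated, "@"-prefixed entry to a file and joins the contents. A standalone Swift illustration of the same logic (the paths are placeholders):

    // Standalone illustration of the "@path,@path" handling in
    // resolvePrompt(configuration:); paths are placeholders.
    let prompt = "@/tmp/a.txt,@/tmp/b.txt"
    if prompt.hasPrefix("@") {
        // "@/tmp/a.txt,@/tmp/b.txt" -> ["/tmp/a.txt", "/tmp/b.txt"]
        let names = prompt.split(separator: ",").map { String($0.dropFirst()) }
        // file contents are joined with newlines into a single prompt
        print(try names.map { try String(contentsOfFile: $0) }.joined(separator: "\n"))
    }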
@@ -87,14 +88,17 @@ struct GenerateArguments: ParsableArguments {
     ) {
         MLXRandom.seed(seed)
 
-        let prompt = try resolvePrompt()
+        let prompt = try resolvePrompt(configuration: configuration)
         let preparedPrompt = configuration.prepare(prompt: prompt)
         let promptTokens = tokenizer.encode(text: preparedPrompt)
 
         return (prompt, promptTokens)
     }
 
-    func generate(promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer) async
+    func generate(
+        promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer,
+        extraEOSTokens: Set<String>? = nil
+    ) async
         -> GenerateResult
     {
         // track how much we have printed
@@ -102,7 +106,7 @@ struct GenerateArguments: ParsableArguments {
 
         return await LLM.generate(
             promptTokens: promptTokens, parameters: generateParameters,
-            model: model, tokenizer: tokenizer
+            model: model, tokenizer: tokenizer, extraEOSTokens: extraEOSTokens
         ) { tokens in
 
             // print any new parts of the string
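Because the extra stop tokens are plain strings rather than token ids, a natural way to honor them is to test the decoded text as it grows. A minimal sketch of such a check, assuming a decode-then-match approach (LLM.generate may implement this differently, e.g. by encoding the stop strings up front):

    // Sketch: stop generation once the decoded text ends with any of the
    // extra stop strings. Function name and approach are assumptions.
    func hitExtraEOS(decodedSoFar: String, extraEOSTokens: Set<String>?) -> Bool {
        guard let stops = extraEOSTokens, !stops.isEmpty else { return false }
        return stops.contains { decodedSoFar.hasSuffix($0) }
    }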
@@ -226,7 +230,8 @@ struct EvaluateCommand: AsyncParsableCommand {
         }
 
         let result = await generate.generate(
-            promptTokens: promptTokens, model: model, tokenizer: tokenizer)
+            promptTokens: promptTokens, model: model, tokenizer: tokenizer,
+            extraEOSTokens: modelConfiguration.extraEOSTokens)
         print()
 
         if !generate.quiet {
@@ -275,7 +275,8 @@ struct LoRAEvalCommand: AsyncParsableCommand {
 
         // generate and print the result
         let _ = await generate.generate(
-            promptTokens: promptTokens, model: model, tokenizer: tokenizer)
+            promptTokens: promptTokens, model: model, tokenizer: tokenizer,
+            extraEOSTokens: modelConfiguration.extraEOSTokens)
         print()
     }
 }
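The repetitionPenalty fix (#71) named in the message falls outside the hunks shown above. For context, the widely used CTRL-style penalty divides positive logits and multiplies negative ones for recently seen tokens; a self-contained sketch over a plain Float array, not the repository's MLX implementation:

    // Sketch of the standard CTRL-style repetition penalty; the repo's
    // fixed code operates on MLX arrays and may differ in detail.
    func applyRepetitionPenalty(
        logits: inout [Float], recentTokens: some Sequence<Int>, penalty: Float
    ) {
        for token in Set(recentTokens) {
            // shrink positive logits, push negative logits further down
            logits[token] = logits[token] < 0
                ? logits[token] * penalty
                : logits[token] / penalty
        }
    }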