handle partially quantized models (#76)

* handle partially quantized models

- fixes #53, #69, #71, #74
- in order to test the models:
	- added a default prompt of an appropriate form to each model configuration
	- while working on the model configuration, also added additional stop tokens (#74); both fields are sketched after the commit header below
- fixed the repetitionPenalty code (#71)
This commit is contained in:
David Koski
2024-05-28 16:35:11 -07:00
committed by GitHub
parent 65f4968e5f
commit 9d74afd119
12 changed files with 139 additions and 67 deletions
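The two configuration fields driving these changes, defaultPrompt and extraEOSTokens, are referenced but not defined in the hunks below. A minimal sketch of how they might sit on ModelConfiguration (field names taken from the diff; defaults and everything else are assumptions, since the real type is in a file not shown in this excerpt):

    // Sketch only: the real ModelConfiguration lives elsewhere in the repo.
    // Field names match the diff; the default values are assumptions.
    struct ModelConfiguration {
        // used by resolvePrompt when the user passes no --prompt
        var defaultPrompt = "hello"

        // extra strings that should terminate generation, per #74
        var extraEOSTokens: Set<String> = []

        // the diff also calls configuration.prepare(prompt:); pass-through here
        func prepare(prompt: String) -> String { prompt }
    }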


@@ -44,7 +44,7 @@ struct GenerateArguments: ParsableArguments {
         help:
             "The message to be processed by the model. Use @path,@path to load from files, e.g. @/tmp/prompt.txt"
     )
-    var prompt = "compare python and swift"
+    var prompt: String?

     @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
     var maxTokens = 100
@@ -73,7 +73,8 @@ struct GenerateArguments: ParsableArguments {
             repetitionContextSize: repetitionContextSize)
     }

-    func resolvePrompt() throws -> String {
+    func resolvePrompt(configuration: ModelConfiguration) throws -> String {
+        let prompt = self.prompt ?? configuration.defaultPrompt
         if prompt.hasPrefix("@") {
             let names = prompt.split(separator: ",").map { String($0.dropFirst()) }
             return try names.map { try String(contentsOfFile: $0) }.joined(separator: "\n")
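For reference, the resolution logic above as a standalone function. This is a sketch, not the repo's API: the real method lives on GenerateArguments, and the function name and parameters here are illustrative.

    import Foundation

    // Standalone sketch of the fallback-plus-@path behavior shown above.
    func resolve(prompt: String?, configurationDefault: String) throws -> String {
        let prompt = prompt ?? configurationDefault  // fall back to the model's default
        guard prompt.hasPrefix("@") else { return prompt }
        // "@/tmp/a.txt,@/tmp/b.txt" -> file contents joined with newlines
        let names = prompt.split(separator: ",").map { String($0.dropFirst()) }
        return try names.map { try String(contentsOfFile: $0) }.joined(separator: "\n")
    }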
@@ -87,14 +88,17 @@ struct GenerateArguments: ParsableArguments {
     ) {
         MLXRandom.seed(seed)

-        let prompt = try resolvePrompt()
+        let prompt = try resolvePrompt(configuration: configuration)
         let preparedPrompt = configuration.prepare(prompt: prompt)
         let promptTokens = tokenizer.encode(text: preparedPrompt)
         return (prompt, promptTokens)
     }

-    func generate(promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer) async
+    func generate(
+        promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer,
+        extraEOSTokens: Set<String>? = nil
+    ) async
         -> GenerateResult
     {
         // track how much we have printed
@@ -102,7 +106,7 @@ struct GenerateArguments: ParsableArguments {
         return await LLM.generate(
             promptTokens: promptTokens, parameters: generateParameters,
-            model: model, tokenizer: tokenizer
+            model: model, tokenizer: tokenizer, extraEOSTokens: extraEOSTokens
         ) { tokens in

             // print any new parts of the string
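The check that actually consumes extraEOSTokens happens inside LLM.generate, which is not part of this excerpt. One plausible shape, matching the decoded suffix against the extra stop strings, might be:

    // Hypothetical stop check; the real implementation in LLM.generate
    // is not shown here, so treat this as an assumption about its shape.
    func shouldStop(decodedTail: String, extraEOSTokens: Set<String>?) -> Bool {
        guard let stops = extraEOSTokens else { return false }
        return stops.contains { decodedTail.hasSuffix($0) }
    }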
@@ -226,7 +230,8 @@ struct EvaluateCommand: AsyncParsableCommand {
         }

         let result = await generate.generate(
-            promptTokens: promptTokens, model: model, tokenizer: tokenizer)
+            promptTokens: promptTokens, model: model, tokenizer: tokenizer,
+            extraEOSTokens: modelConfiguration.extraEOSTokens)
         print()

         if !generate.quiet {