handle partially quantized models (#76)
* Handle partially quantized models — fixes #53, #69, #71, and #74. To make the models testable, I added a default prompt of an appropriate form for each model. While working on the model configuration, I also added additional stop tokens (#74) and fixed the repetitionPenalty code (#71).
This commit is contained in:
@@ -33,6 +33,12 @@ public struct ModelConfiguration {
|
||||
/// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
|
||||
public let overrideTokenizer: String?
|
||||
|
||||
/// A reasonable default prompt for the model
|
||||
public let defaultPrompt: String
|
||||
|
||||
/// Additional tokens to use for end of string
|
||||
public let extraEOSTokens: Set<String>
|
||||
|
||||
/// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
|
||||
/// allows some minor formatting changes, e.g. wrapping the user input in the expected prompt
|
||||
/// format
|
||||
@@ -40,21 +46,29 @@ public struct ModelConfiguration {
|
||||
|
||||
/// Creates a configuration for a model identified by a hub id.
///
/// - Parameters:
///   - id: hub identifier for the model, stored as ``Identifier/id(_:)``
///   - tokenizerId: optional identifier for a tokenizer to use in place of the model's own
///   - overrideTokenizer: override for TokenizerModel/knownTokenizers -- useful before
///     swift-transformers is updated
///   - defaultPrompt: a reasonable default prompt for the model
///   - extraEOSTokens: additional tokens to use for end of string
///   - preparePrompt: custom preparation logic for the prompt; allows minor formatting
///     changes, e.g. wrapping the user input in the expected prompt format
public init(
    id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
    defaultPrompt: String = "hello",
    extraEOSTokens: Set<String> = [],
    preparePrompt: ((String) -> String)? = nil
) {
    self.id = .id(id)
    self.tokenizerId = tokenizerId
    self.overrideTokenizer = overrideTokenizer
    self.defaultPrompt = defaultPrompt
    self.extraEOSTokens = extraEOSTokens
    self.preparePrompt = preparePrompt
}
|
||||
|
||||
/// Creates a configuration for a model loaded from a local directory.
///
/// - Parameters:
///   - directory: file URL of the directory containing the model, stored as
///     ``Identifier/directory(_:)``
///   - tokenizerId: optional identifier for a tokenizer to use in place of the model's own
///   - overrideTokenizer: override for TokenizerModel/knownTokenizers -- useful before
///     swift-transformers is updated
///   - defaultPrompt: a reasonable default prompt for the model
///   - extraEOSTokens: additional tokens to use for end of string
///   - preparePrompt: custom preparation logic for the prompt; allows minor formatting
///     changes, e.g. wrapping the user input in the expected prompt format
public init(
    directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
    defaultPrompt: String = "hello",
    extraEOSTokens: Set<String> = [],
    preparePrompt: ((String) -> String)? = nil
) {
    self.id = .directory(directory)
    self.tokenizerId = tokenizerId
    self.overrideTokenizer = overrideTokenizer
    self.defaultPrompt = defaultPrompt
    self.extraEOSTokens = extraEOSTokens
    self.preparePrompt = preparePrompt
}
|
||||
|
||||
@@ -98,11 +112,16 @@ public struct ModelConfiguration {
|
||||
extension ModelConfiguration {
|
||||
|
||||
public static let mistral7B4bit = ModelConfiguration(
|
||||
id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
|
||||
id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
|
||||
|
||||
// https://www.promptingguide.ai/models/mistral-7b
|
||||
defaultPrompt: "describe the swift language"
|
||||
)
|
||||
|
||||
public static let codeLlama13b4bit = ModelConfiguration(
|
||||
id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
|
||||
overrideTokenizer: "PreTrainedTokenizer"
|
||||
overrideTokenizer: "PreTrainedTokenizer",
|
||||
defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
|
||||
) { prompt in
|
||||
// given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
|
||||
// the python code produces this (via its custom tokenizer):
|
||||
@@ -111,13 +130,17 @@ extension ModelConfiguration {
|
||||
"<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
|
||||
}
|
||||
|
||||
public static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") {
|
||||
prompt in
|
||||
"Instruct: \(prompt)\nOutput: "
|
||||
}
|
||||
public static let phi4bit = ModelConfiguration(
|
||||
id: "mlx-community/phi-2-hf-4bit-mlx",
|
||||
|
||||
// https://www.promptingguide.ai/models/phi-2
|
||||
defaultPrompt: "Why is the sky blue?"
|
||||
)
|
||||
|
||||
public static let phi34bit = ModelConfiguration(
|
||||
id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed"
|
||||
id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
|
||||
defaultPrompt: "what is the gravity on mars and the moon?",
|
||||
extraEOSTokens: ["<|end|>"]
|
||||
) {
|
||||
prompt in
|
||||
"<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
|
||||
@@ -125,26 +148,35 @@ extension ModelConfiguration {
|
||||
|
||||
public static let gemma2bQuantized = ModelConfiguration(
|
||||
id: "mlx-community/quantized-gemma-2b-it",
|
||||
overrideTokenizer: "PreTrainedTokenizer"
|
||||
overrideTokenizer: "PreTrainedTokenizer",
|
||||
|
||||
// https://www.promptingguide.ai/models/gemma
|
||||
defaultPrompt: "what is the difference between lettuce and cabbage?"
|
||||
|
||||
) { prompt in
|
||||
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
|
||||
}
|
||||
|
||||
public static let qwen205b4bit = ModelConfiguration(
|
||||
id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
|
||||
overrideTokenizer: "PreTrainedTokenizer"
|
||||
overrideTokenizer: "PreTrainedTokenizer",
|
||||
defaultPrompt: "why is the sky blue?"
|
||||
) { prompt in
|
||||
"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant"
|
||||
}
|
||||
|
||||
public static let openelm270m4bit = ModelConfiguration(
|
||||
id: "mlx-community/OpenELM-270M-Instruct"
|
||||
id: "mlx-community/OpenELM-270M-Instruct",
|
||||
|
||||
// https://huggingface.co/apple/OpenELM
|
||||
defaultPrompt: "Once upon a time there was"
|
||||
) { prompt in
|
||||
"\(prompt)"
|
||||
}
|
||||
|
||||
public static let llama38B4bit = ModelConfiguration(
|
||||
id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
|
||||
id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
|
||||
defaultPrompt: "what is the difference between a fruit and a vegetable?"
|
||||
) {
|
||||
prompt in
|
||||
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
|
||||
|
||||
Reference in New Issue
Block a user