Add Mistral NeMo (#97)

* Update Mistral 7B config

* Add Mistral NeMo
This commit is contained in:
Anthony DePasquale
2024-07-25 17:01:26 +02:00
committed by GitHub
parent a2e8d7e469
commit c4fda0e036
2 changed files with 19 additions and 6 deletions

View File

@@ -26,7 +26,7 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads
-        let headDim = args.hiddenSize / heads
+        let headDim = args.headDimensions ?? (args.hiddenSize / heads)
         self.scale = pow(Float(headDim), -0.5)
         self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
@@ -215,6 +215,7 @@ public struct LlamaConfiguration: Codable {
     var hiddenLayers: Int
     var intermediateSize: Int
     var attentionHeads: Int
+    var headDimensions: Int? = nil
     var rmsNormEps: Float
     var vocabularySize: Int
     var kvHeads: Int
@@ -228,6 +229,7 @@ public struct LlamaConfiguration: Codable {
         case hiddenLayers = "num_hidden_layers"
         case intermediateSize = "intermediate_size"
         case attentionHeads = "num_attention_heads"
+        case headDimensions = "head_dim"
         case rmsNormEps = "rms_norm_eps"
         case vocabularySize = "vocab_size"
         case kvHeads = "num_key_value_heads"
@@ -251,6 +253,8 @@ public struct LlamaConfiguration: Codable {
             Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
         self.attentionHeads = try container.decode(
             Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
+        self.headDimensions = try container.decodeIfPresent(
+            Int.self, forKey: LlamaConfiguration.CodingKeys.headDimensions)
         self.rmsNormEps = try container.decode(
             Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
         self.vocabularySize = try container.decode(

View File

@@ -118,12 +118,19 @@ extension ModelConfiguration {
         "<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n"
     }

+    public static let mistralNeMo4bit = ModelConfiguration(
+        id: "mlx-community/Mistral-Nemo-Instruct-2407-4bit",
+        defaultPrompt: "Explain quaternions."
+    ) { prompt in
+        "<s>[INST] \(prompt) [/INST] "
+    }
+
     public static let mistral7B4bit = ModelConfiguration(
-        id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
-        // https://www.promptingguide.ai/models/mistral-7b
-        defaultPrompt: "describe the swift language"
-    )
+        id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
+        defaultPrompt: "Describe the Swift language."
+    ) { prompt in
+        "<s>[INST] \(prompt) [/INST] "
+    }

     public static let codeLlama13b4bit = ModelConfiguration(
         id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
@@ -213,6 +220,8 @@ extension ModelConfiguration {
         case .idle:
             bootstrapState = .bootstrapping
             register(configurations: [
+                mistralNeMo4bit,
+                smolLM_135M_4bit,
                 mistral7B4bit,
                 codeLlama13b4bit,
                 phi4bit,