From c4fda0e0364c0fa348308decc22d52176393a151 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Thu, 25 Jul 2024 17:01:26 +0200 Subject: [PATCH] Add Mistral NeMo (#97) * Update Mistral 7B config * Add Mistral NeMo --- Libraries/LLM/Llama.swift | 6 +++++- Libraries/LLM/Models.swift | 19 ++++++++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/Libraries/LLM/Llama.swift b/Libraries/LLM/Llama.swift index c9ef4cc..80b3387 100644 --- a/Libraries/LLM/Llama.swift +++ b/Libraries/LLM/Llama.swift @@ -26,7 +26,7 @@ private class Attention: Module { let heads = args.attentionHeads let kvHeads = args.kvHeads - let headDim = args.hiddenSize / heads + let headDim = args.headDimensions ?? (args.hiddenSize / heads) self.scale = pow(Float(headDim), -0.5) self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false) @@ -215,6 +215,7 @@ public struct LlamaConfiguration: Codable { var hiddenLayers: Int var intermediateSize: Int var attentionHeads: Int + var headDimensions: Int? = nil var rmsNormEps: Float var vocabularySize: Int var kvHeads: Int @@ -228,6 +229,7 @@ public struct LlamaConfiguration: Codable { case hiddenLayers = "num_hidden_layers" case intermediateSize = "intermediate_size" case attentionHeads = "num_attention_heads" + case headDimensions = "head_dim" case rmsNormEps = "rms_norm_eps" case vocabularySize = "vocab_size" case kvHeads = "num_key_value_heads" @@ -251,6 +253,8 @@ public struct LlamaConfiguration: Codable { Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize) self.attentionHeads = try container.decode( Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads) + self.headDimensions = try container.decodeIfPresent( + Int.self, forKey: LlamaConfiguration.CodingKeys.headDimensions) self.rmsNormEps = try container.decode( Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps) self.vocabularySize = try container.decode( diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift index 229fb94..73ab0ab 100644 --- a/Libraries/LLM/Models.swift +++ b/Libraries/LLM/Models.swift @@ -118,12 +118,19 @@ extension ModelConfiguration { "<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n" } - public static let mistral7B4bit = ModelConfiguration( - id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx", + public static let mistralNeMo4bit = ModelConfiguration( + id: "mlx-community/Mistral-Nemo-Instruct-2407-4bit", + defaultPrompt: "Explain quaternions." + ) { prompt in + "[INST] \(prompt) [/INST] " + } - // https://www.promptingguide.ai/models/mistral-7b - defaultPrompt: "describe the swift language" - ) + public static let mistral7B4bit = ModelConfiguration( + id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit", + defaultPrompt: "Describe the Swift language." + ) { prompt in + "[INST] \(prompt) [/INST] " + } public static let codeLlama13b4bit = ModelConfiguration( id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX", @@ -213,6 +220,8 @@ extension ModelConfiguration { case .idle: bootstrapState = .bootstrapping register(configurations: [ + mistralNeMo4bit, + smolLM_135M_4bit, mistral7B4bit, codeLlama13b4bit, phi4bit,