From 63531bc5fa8033999cfae923c89af003c29531b9 Mon Sep 17 00:00:00 2001 From: Anchen Date: Wed, 27 Mar 2024 06:43:10 +1100 Subject: [PATCH] feat: add command r model support (#35) * feat: add command r model support --- Libraries/LLM/Cohere.swift | 238 ++++++++++++++++++ Libraries/LLM/Configuration.swift | 5 + Libraries/LLM/Tokenizer.swift | 3 +- mlx-swift-examples.xcodeproj/project.pbxproj | 4 + .../xcshareddata/xcschemes/llm-tool.xcscheme | 6 +- 5 files changed, 254 insertions(+), 2 deletions(-) create mode 100644 Libraries/LLM/Cohere.swift diff --git a/Libraries/LLM/Cohere.swift b/Libraries/LLM/Cohere.swift new file mode 100644 index 0000000..8ccbfbe --- /dev/null +++ b/Libraries/LLM/Cohere.swift @@ -0,0 +1,238 @@ +import Foundation +import MLX +import MLXFast +import MLXNN + +// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/cohere.py + +private class Attention: Module { + + let args: CohereConfiguration + let scale: Float + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "o_proj") var wo: Linear + + let rope: RoPE + + public init(_ args: CohereConfiguration) { + self.args = args + + let dim = args.hiddenSize + let heads = args.attentionHeads + let kvHeads = args.kvHeads + + let headDim = args.hiddenSize / heads + self.scale = pow(Float(headDim), -0.5) + + self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false) + self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false) + self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false) + self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false) + + self.rope = RoPE( + dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil + ) -> (MLXArray, (MLXArray, MLXArray)) { + let (B, L) = (x.dim(0), x.dim(1)) + + var queries = wq(x) + var keys = wk(x) + var values = wv(x) + + // prepare the queries, keys and values for the attention computation + queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + + if let (keyCache, valueCache) = cache { + queries = rope(queries, offset: keyCache.dim(2)) + keys = rope(keys, offset: keyCache.dim(2)) + keys = concatenated([keyCache, keys], axis: 2) + values = concatenated([valueCache, values], axis: 2) + } else { + queries = rope(queries) + keys = rope(keys) + } + + let output = MLXFast.scaledDotProductAttention( + queries: queries, keys: keys, values: values, scale: scale, mask: mask + ) + .transposed(0, 2, 1, 3) + .reshaped(B, L, -1) + + return (wo(output), (keys, values)) + } +} + +private class MLP: Module, UnaryLayer { + + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "down_proj") var down: Linear + @ModuleInfo(key: "up_proj") var up: Linear + + public init(dimensions: Int, hiddenDimensions: Int) { + self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) + + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + down(silu(gate(x)) * up(x)) + } +} + +private class TransformerBlock: Module { + + @ModuleInfo(key: "self_attn") var attention: Attention + let mlp: MLP + + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: LayerNorm + + public init(_ args: CohereConfiguration) { + self._attention.wrappedValue = Attention(args) + self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) + self._inputLayerNorm.wrappedValue = LayerNorm( + dimensions: args.hiddenSize, eps: args.layerNormEps) + + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil + ) -> (MLXArray, (MLXArray, MLXArray)) { + let h = inputLayerNorm(x) + let (attnH, cache) = attention(h, mask: mask, cache: cache) + let ffH = mlp(h) + return (attnH + ffH + x, cache) + } +} + +public class CohereModelInner: Module { + + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + + fileprivate let layers: [TransformerBlock] + let norm: LayerNorm + + public init(_ args: CohereConfiguration) { + precondition(args.vocabularySize > 0) + + self._embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + + self.layers = (0 ..< args.hiddenLayers) + .map { _ in + TransformerBlock(args) + } + self.norm = LayerNorm(dimensions: args.hiddenSize, eps: args.layerNormEps) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> ( + MLXArray, [(MLXArray, MLXArray)] + ) { + var h = embedTokens(inputs) + + var mask: MLXArray? = nil + if h.dim(1) > 1 { + mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1)) + mask = mask?.asType(h.dtype) + } + + var newCache = [(MLXArray, MLXArray)]() + + for (i, layer) in layers.enumerated() { + var cacheUpdate: (MLXArray, MLXArray) + (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i]) + newCache.append(cacheUpdate) + } + + return (norm(h), newCache) + } +} + +public class CohereModel: Module, LLMModel { + + public let vocabularySize: Int + let model: CohereModelInner + let logitScale: Float + + public init(_ args: CohereConfiguration) { + self.vocabularySize = args.vocabularySize + self.model = CohereModelInner(args) + self.logitScale = args.logitScale + } + + public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> ( + MLXArray, [(MLXArray, MLXArray)] + ) { + var (out, cache) = model(inputs, cache: cache) + out = matmul(out, model.embedTokens.weight.T) + out = out * self.logitScale + return (out, cache) + } +} + +public struct CohereConfiguration: Codable { + + var hiddenSize: Int + var hiddenLayers: Int + var intermediateSize: Int + var attentionHeads: Int + var layerNormEps: Float + var vocabularySize: Int + var kvHeads: Int + var ropeTheta: Float = 8000000.0 + var ropeTraditional: Bool = true + var ropeScaling: [String: StringOrNumber]? = nil + var logitScale: Float + + enum CodingKeys: String, CodingKey { + case hiddenSize = "hidden_size" + case hiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case attentionHeads = "num_attention_heads" + case kvHeads = "num_key_value_heads" + case ropeTheta = "rope_theta" + case vocabularySize = "vocab_size" + case layerNormEps = "layer_norm_eps" + case logitScale = "logit_scale" + case ropeTraditional = "rope_traditional" + case ropeScaling = "rope_scaling" + } + + public init(from decoder: Decoder) throws { + // custom implementation to handle optional keys with required values + let container: KeyedDecodingContainer = + try decoder.container( + keyedBy: CohereConfiguration.CodingKeys.self) + + self.hiddenSize = try container.decode( + Int.self, forKey: CohereConfiguration.CodingKeys.hiddenSize) + self.hiddenLayers = try container.decode( + Int.self, forKey: CohereConfiguration.CodingKeys.hiddenLayers) + self.intermediateSize = try container.decode( + Int.self, forKey: CohereConfiguration.CodingKeys.intermediateSize) + self.attentionHeads = try container.decode( + Int.self, forKey: CohereConfiguration.CodingKeys.attentionHeads) + self.layerNormEps = try container.decode( + Float.self, forKey: CohereConfiguration.CodingKeys.layerNormEps) + self.vocabularySize = try container.decode( + Int.self, forKey: CohereConfiguration.CodingKeys.vocabularySize) + self.kvHeads = try container.decode( + Int.self, forKey: CohereConfiguration.CodingKeys.kvHeads) + self.ropeTheta = + try container.decodeIfPresent( + Float.self, forKey: CohereConfiguration.CodingKeys.ropeTheta) + ?? 8000000.0 + self.ropeScaling = try container.decodeIfPresent( + [String: StringOrNumber].self, forKey: CohereConfiguration.CodingKeys.ropeScaling) + self.logitScale = try container.decode( + Float.self, forKey: CohereConfiguration.CodingKeys.logitScale) + } +} diff --git a/Libraries/LLM/Configuration.swift b/Libraries/LLM/Configuration.swift index f207258..d1fdedc 100644 --- a/Libraries/LLM/Configuration.swift +++ b/Libraries/LLM/Configuration.swift @@ -33,6 +33,7 @@ public enum ModelType: String, Codable { case gemma case qwen2 case starcoder2 + case cohere func createModel(configuration: URL) throws -> LLMModel { switch self { @@ -56,6 +57,10 @@ public enum ModelType: String, Codable { let configuration = try JSONDecoder().decode( Starcoder2Configuration.self, from: Data(contentsOf: configuration)) return Starcoder2Model(configuration) + case .cohere: + let configuration = try JSONDecoder().decode( + CohereConfiguration.self, from: Data(contentsOf: configuration)) + return CohereModel(configuration) } } } diff --git a/Libraries/LLM/Tokenizer.swift b/Libraries/LLM/Tokenizer.swift index bcd6fb1..fcc5764 100644 --- a/Libraries/LLM/Tokenizer.swift +++ b/Libraries/LLM/Tokenizer.swift @@ -28,5 +28,6 @@ public func loadTokenizer(configuration: ModelConfiguration) async throws -> Tok /// overrides for TokenizerModel/knownTokenizers let replacementTokenizers = [ - "Qwen2Tokenizer": "PreTrainedTokenizer" + "Qwen2Tokenizer": "PreTrainedTokenizer", + "CohereTokenizer": "PreTrainedTokenizer", ] diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj index 4446023..4477aae 100644 --- a/mlx-swift-examples.xcodeproj/project.pbxproj +++ b/mlx-swift-examples.xcodeproj/project.pbxproj @@ -68,6 +68,7 @@ C3FBCB312B8520F20007E490 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB302B8520F20007E490 /* MLXNN */; }; C3FBCB332B8520F20007E490 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB322B8520F20007E490 /* MLXOptimizers */; }; C3FBCB352B8520F20007E490 /* MLXRandom in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB342B8520F20007E490 /* MLXRandom */; }; + F24B083A2BAF1A65008C8D19 /* Cohere.swift in Sources */ = {isa = PBXBuildFile; fileRef = F24B08392BAF1A65008C8D19 /* Cohere.swift */; }; /* End PBXBuildFile section */ /* Begin PBXContainerItemProxy section */ @@ -234,6 +235,7 @@ C3C3240C2B6CA792007D2D9A /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Evaluate.swift; sourceTree = ""; }; C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Tokenizer.swift; sourceTree = ""; }; + F24B08392BAF1A65008C8D19 /* Cohere.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Cohere.swift; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -383,6 +385,7 @@ C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */, C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */, 52A776172B94B5EE00AA6E80 /* Qwen2.swift */, + F24B08392BAF1A65008C8D19 /* Cohere.swift */, ); path = LLM; sourceTree = ""; @@ -847,6 +850,7 @@ buildActionMask = 2147483647; files = ( C38935E12B869F420037B833 /* LLMModel.swift in Sources */, + F24B083A2BAF1A65008C8D19 /* Cohere.swift in Sources */, C38935E32B86C0FE0037B833 /* Gemma.swift in Sources */, C38935CD2B869C870037B833 /* Configuration.swift in Sources */, 525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */, diff --git a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme index f730513..5dcb893 100644 --- a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme +++ b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme @@ -56,9 +56,13 @@ isEnabled = "NO"> + +