From 63531bc5fa8033999cfae923c89af003c29531b9 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Wed, 27 Mar 2024 06:43:10 +1100
Subject: [PATCH] feat: add command r model support (#35)
* feat: add command r model support
---
Libraries/LLM/Cohere.swift | 238 ++++++++++++++++++
Libraries/LLM/Configuration.swift | 5 +
Libraries/LLM/Tokenizer.swift | 3 +-
mlx-swift-examples.xcodeproj/project.pbxproj | 4 +
.../xcshareddata/xcschemes/llm-tool.xcscheme | 6 +-
5 files changed, 254 insertions(+), 2 deletions(-)
create mode 100644 Libraries/LLM/Cohere.swift
diff --git a/Libraries/LLM/Cohere.swift b/Libraries/LLM/Cohere.swift
new file mode 100644
index 0000000..8ccbfbe
--- /dev/null
+++ b/Libraries/LLM/Cohere.swift
@@ -0,0 +1,238 @@
+import Foundation
+import MLX
+import MLXFast
+import MLXNN
+
+// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/cohere.py
+
+private class Attention: Module {
+
+ let args: CohereConfiguration
+ let scale: Float
+
+ @ModuleInfo(key: "q_proj") var wq: Linear
+ @ModuleInfo(key: "k_proj") var wk: Linear
+ @ModuleInfo(key: "v_proj") var wv: Linear
+ @ModuleInfo(key: "o_proj") var wo: Linear
+
+ let rope: RoPE
+
+ public init(_ args: CohereConfiguration) {
+ self.args = args
+
+ let dim = args.hiddenSize
+ let heads = args.attentionHeads
+ let kvHeads = args.kvHeads
+
+ let headDim = args.hiddenSize / heads
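+ // attention scores are scaled by 1/sqrt(headDim)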
+ self.scale = pow(Float(headDim), -0.5)
+
+ self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
+ self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
+ self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
+ self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
+
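+ // rotary position embeddings; the base frequency comes from rope_theta in config.json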
+ self.rope = RoPE(
+ dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
+ }
+
+ public func callAsFunction(
+ _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
+ ) -> (MLXArray, (MLXArray, MLXArray)) {
+ let (B, L) = (x.dim(0), x.dim(1))
+
+ var queries = wq(x)
+ var keys = wk(x)
+ var values = wv(x)
+
+ // prepare the queries, keys and values for the attention computation
+ queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
+ keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+ values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+
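+ // with a KV cache, rotate queries/keys at the cached sequence offset and
+ // append the new keys and values to the cached ones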
+ if let (keyCache, valueCache) = cache {
+ queries = rope(queries, offset: keyCache.dim(2))
+ keys = rope(keys, offset: keyCache.dim(2))
+ keys = concatenated([keyCache, keys], axis: 2)
+ values = concatenated([valueCache, values], axis: 2)
+ } else {
+ queries = rope(queries)
+ keys = rope(keys)
+ }
+
+ let output = MLXFast.scaledDotProductAttention(
+ queries: queries, keys: keys, values: values, scale: scale, mask: mask
+ )
+ .transposed(0, 2, 1, 3)
+ .reshaped(B, L, -1)
+
+ return (wo(output), (keys, values))
+ }
+}
+
+private class MLP: Module, UnaryLayer {
+
+ @ModuleInfo(key: "gate_proj") var gate: Linear
+ @ModuleInfo(key: "down_proj") var down: Linear
+ @ModuleInfo(key: "up_proj") var up: Linear
+
+ public init(dimensions: Int, hiddenDimensions: Int) {
+ self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
+ self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
+ self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
+
+ }
+
+ public func callAsFunction(_ x: MLXArray) -> MLXArray {
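+ // SwiGLU-style feed-forward: the SiLU-gated projection is multiplied by the
+ // up projection, then projected back down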
+ down(silu(gate(x)) * up(x))
+ }
+}
+
+private class TransformerBlock: Module {
+
+ @ModuleInfo(key: "self_attn") var attention: Attention
+ let mlp: MLP
+
+ @ModuleInfo(key: "input_layernorm") var inputLayerNorm: LayerNorm
+
+ public init(_ args: CohereConfiguration) {
+ self._attention.wrappedValue = Attention(args)
+ self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
+ self._inputLayerNorm.wrappedValue = LayerNorm(
+ dimensions: args.hiddenSize, eps: args.layerNormEps)
+
+ }
+
+ public func callAsFunction(
+ _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
+ ) -> (MLXArray, (MLXArray, MLXArray)) {
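+ // parallel residual: attention and MLP both read the same layer-normed
+ // input and their outputs are added back to the residual stream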
+ let h = inputLayerNorm(x)
+ let (attnH, cache) = attention(h, mask: mask, cache: cache)
+ let ffH = mlp(h)
+ return (attnH + ffH + x, cache)
+ }
+}
+
+public class CohereModelInner: Module {
+
+ @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
+
+ fileprivate let layers: [TransformerBlock]
+ let norm: LayerNorm
+
+ public init(_ args: CohereConfiguration) {
+ precondition(args.vocabularySize > 0)
+
+ self._embedTokens.wrappedValue = Embedding(
+ embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
+
+ self.layers = (0 ..< args.hiddenLayers)
+ .map { _ in
+ TransformerBlock(args)
+ }
+ self.norm = LayerNorm(dimensions: args.hiddenSize, eps: args.layerNormEps)
+ }
+
+ public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
+ MLXArray, [(MLXArray, MLXArray)]
+ ) {
+ var h = embedTokens(inputs)
+
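+ // an additive causal mask is only needed when processing more than one token (e.g. the prompt)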
+ var mask: MLXArray? = nil
+ if h.dim(1) > 1 {
+ mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
+ mask = mask?.asType(h.dtype)
+ }
+
+ var newCache = [(MLXArray, MLXArray)]()
+
+ for (i, layer) in layers.enumerated() {
+ var cacheUpdate: (MLXArray, MLXArray)
+ (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
+ newCache.append(cacheUpdate)
+ }
+
+ return (norm(h), newCache)
+ }
+}
+
+public class CohereModel: Module, LLMModel {
+
+ public let vocabularySize: Int
+ let model: CohereModelInner
+ let logitScale: Float
+
+ public init(_ args: CohereConfiguration) {
+ self.vocabularySize = args.vocabularySize
+ self.model = CohereModelInner(args)
+ self.logitScale = args.logitScale
+ }
+
+ public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
+ MLXArray, [(MLXArray, MLXArray)]
+ ) {
+ var (out, cache) = model(inputs, cache: cache)
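+ // the output projection is tied to the input embedding matrix; the logits
+ // are then multiplied by the model's logit_scale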
+ out = matmul(out, model.embedTokens.weight.T)
+ out = out * self.logitScale
+ return (out, cache)
+ }
+}
+
+public struct CohereConfiguration: Codable {
+
+ var hiddenSize: Int
+ var hiddenLayers: Int
+ var intermediateSize: Int
+ var attentionHeads: Int
+ var layerNormEps: Float
+ var vocabularySize: Int
+ var kvHeads: Int
+ var ropeTheta: Float = 8000000.0
+ var ropeTraditional: Bool = true
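+ // rope_scaling is decoded for config compatibility but is not otherwise used in this file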
+ var ropeScaling: [String: StringOrNumber]? = nil
+ var logitScale: Float
+
+ enum CodingKeys: String, CodingKey {
+ case hiddenSize = "hidden_size"
+ case hiddenLayers = "num_hidden_layers"
+ case intermediateSize = "intermediate_size"
+ case attentionHeads = "num_attention_heads"
+ case kvHeads = "num_key_value_heads"
+ case ropeTheta = "rope_theta"
+ case vocabularySize = "vocab_size"
+ case layerNormEps = "layer_norm_eps"
+ case logitScale = "logit_scale"
+ case ropeTraditional = "rope_traditional"
+ case ropeScaling = "rope_scaling"
+ }
+
+ public init(from decoder: Decoder) throws {
+ // custom decoding so that keys which may be absent from config.json fall back to defaults
+ let container: KeyedDecodingContainer =
+ try decoder.container(
+ keyedBy: CohereConfiguration.CodingKeys.self)
+
+ self.hiddenSize = try container.decode(
+ Int.self, forKey: CohereConfiguration.CodingKeys.hiddenSize)
+ self.hiddenLayers = try container.decode(
+ Int.self, forKey: CohereConfiguration.CodingKeys.hiddenLayers)
+ self.intermediateSize = try container.decode(
+ Int.self, forKey: CohereConfiguration.CodingKeys.intermediateSize)
+ self.attentionHeads = try container.decode(
+ Int.self, forKey: CohereConfiguration.CodingKeys.attentionHeads)
+ self.layerNormEps = try container.decode(
+ Float.self, forKey: CohereConfiguration.CodingKeys.layerNormEps)
+ self.vocabularySize = try container.decode(
+ Int.self, forKey: CohereConfiguration.CodingKeys.vocabularySize)
+ self.kvHeads = try container.decode(
+ Int.self, forKey: CohereConfiguration.CodingKeys.kvHeads)
+ self.ropeTheta =
+ try container.decodeIfPresent(
+ Float.self, forKey: CohereConfiguration.CodingKeys.ropeTheta)
+ ?? 8000000.0
+ self.ropeScaling = try container.decodeIfPresent(
+ [String: StringOrNumber].self, forKey: CohereConfiguration.CodingKeys.ropeScaling)
+ self.logitScale = try container.decode(
+ Float.self, forKey: CohereConfiguration.CodingKeys.logitScale)
+ }
+}
diff --git a/Libraries/LLM/Configuration.swift b/Libraries/LLM/Configuration.swift
index f207258..d1fdedc 100644
--- a/Libraries/LLM/Configuration.swift
+++ b/Libraries/LLM/Configuration.swift
@@ -33,6 +33,7 @@ public enum ModelType: String, Codable {
case gemma
case qwen2
case starcoder2
+ case cohere
func createModel(configuration: URL) throws -> LLMModel {
switch self {
@@ -56,6 +57,10 @@ public enum ModelType: String, Codable {
let configuration = try JSONDecoder().decode(
Starcoder2Configuration.self, from: Data(contentsOf: configuration))
return Starcoder2Model(configuration)
+ case .cohere:
+ let configuration = try JSONDecoder().decode(
+ CohereConfiguration.self, from: Data(contentsOf: configuration))
+ return CohereModel(configuration)
}
}
}
diff --git a/Libraries/LLM/Tokenizer.swift b/Libraries/LLM/Tokenizer.swift
index bcd6fb1..fcc5764 100644
--- a/Libraries/LLM/Tokenizer.swift
+++ b/Libraries/LLM/Tokenizer.swift
@@ -28,5 +28,6 @@ public func loadTokenizer(configuration: ModelConfiguration) async throws -> Tok
/// overrides for TokenizerModel/knownTokenizers
let replacementTokenizers = [
- "Qwen2Tokenizer": "PreTrainedTokenizer"
+ "Qwen2Tokenizer": "PreTrainedTokenizer",
+ "CohereTokenizer": "PreTrainedTokenizer",
]
diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj
index 4446023..4477aae 100644
--- a/mlx-swift-examples.xcodeproj/project.pbxproj
+++ b/mlx-swift-examples.xcodeproj/project.pbxproj
@@ -68,6 +68,7 @@
C3FBCB312B8520F20007E490 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB302B8520F20007E490 /* MLXNN */; };
C3FBCB332B8520F20007E490 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB322B8520F20007E490 /* MLXOptimizers */; };
C3FBCB352B8520F20007E490 /* MLXRandom in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB342B8520F20007E490 /* MLXRandom */; };
+ F24B083A2BAF1A65008C8D19 /* Cohere.swift in Sources */ = {isa = PBXBuildFile; fileRef = F24B08392BAF1A65008C8D19 /* Cohere.swift */; };
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
@@ -234,6 +235,7 @@
C3C3240C2B6CA792007D2D9A /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; };
C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Evaluate.swift; sourceTree = ""; };
C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Tokenizer.swift; sourceTree = ""; };
+ F24B08392BAF1A65008C8D19 /* Cohere.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Cohere.swift; sourceTree = ""; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
@@ -383,6 +385,7 @@
C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */,
C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */,
52A776172B94B5EE00AA6E80 /* Qwen2.swift */,
+ F24B08392BAF1A65008C8D19 /* Cohere.swift */,
);
path = LLM;
sourceTree = "";
@@ -847,6 +850,7 @@
buildActionMask = 2147483647;
files = (
C38935E12B869F420037B833 /* LLMModel.swift in Sources */,
+ F24B083A2BAF1A65008C8D19 /* Cohere.swift in Sources */,
C38935E32B86C0FE0037B833 /* Gemma.swift in Sources */,
C38935CD2B869C870037B833 /* Configuration.swift in Sources */,
525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */,
diff --git a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
index f730513..5dcb893 100644
--- a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
+++ b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
@@ -56,9 +56,13 @@
isEnabled = "NO">
+
+