From 4d20785b5d462e9e0eb6e8647966a2ecf3abc2ae Mon Sep 17 00:00:00 2001 From: Sachin Desai Date: Tue, 30 Apr 2024 09:14:27 -0700 Subject: [PATCH] add support for OpenELM (#63) * add support for OpenELM * register model configuration for bootstrap --- Libraries/LLM/Configuration.swift | 5 + Libraries/LLM/Models.swift | 7 + Libraries/LLM/OpenELM.swift | 316 +++++++++++++++++++ mlx-swift-examples.xcodeproj/project.pbxproj | 4 + 4 files changed, 332 insertions(+) create mode 100644 Libraries/LLM/OpenELM.swift diff --git a/Libraries/LLM/Configuration.swift b/Libraries/LLM/Configuration.swift index c4fb906..84ee410 100644 --- a/Libraries/LLM/Configuration.swift +++ b/Libraries/LLM/Configuration.swift @@ -35,6 +35,7 @@ public enum ModelType: String, Codable { case qwen2 case starcoder2 case cohere + case openelm public func createModel(configuration: URL) throws -> LLMModel { switch self { @@ -66,6 +67,10 @@ public enum ModelType: String, Codable { let configuration = try JSONDecoder().decode( CohereConfiguration.self, from: Data(contentsOf: configuration)) return CohereModel(configuration) + case .openelm: + let configuration = try JSONDecoder().decode( + OpenElmConfiguration.self, from: Data(contentsOf: configuration)) + return OpenELMModel(configuration) } } } diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift index 2fe52d7..917d59f 100644 --- a/Libraries/LLM/Models.swift +++ b/Libraries/LLM/Models.swift @@ -137,6 +137,12 @@ extension ModelConfiguration { "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant" } + public static let openelm270m4bit = ModelConfiguration( + id: "mlx-community/OpenELM-270M-Instruct" + ) { prompt in + "\(prompt)" + } + private enum BootstrapState { case idle case bootstrapping @@ -156,6 +162,7 @@ extension ModelConfiguration { phi34bit, gemma2bQuantized, qwen205b4bit, + openelm270m4bit, ]) bootstrapState = .bootstrapped diff --git a/Libraries/LLM/OpenELM.swift b/Libraries/LLM/OpenELM.swift new file mode 100644 index 0000000..13c98b3 --- /dev/null +++ b/Libraries/LLM/OpenELM.swift @@ -0,0 +1,316 @@ +// +// OpenELM.swift +// LLM +// +// Created by Sachin Desai on 2024/4/27. +// + +import Foundation +import MLX +import MLXFast +import MLXNN + +func computeHeads(modelDim: Int, headDim: Int) -> Int { + assert(modelDim % headDim == 0, "modelDim must be divisible by headDim") + return modelDim / headDim +} + +func makeDivisible(_ v: Float, divisor: Int = 8, minValue: Float? = nil) -> Int { + let minVal = minValue ?? Float(divisor) + var roundDown = max(minVal, Float(Int((v + Float(divisor) / 2) / Float(divisor)) * divisor)) + + if roundDown < 0.9 * v { + roundDown += Float(divisor) + } + return Int(roundDown) +} + +private class MultiHeadCausalAttention: Module { + var args: OpenElmConfiguration + let scale: Float + let heads: Int + let headDim: Int + let kvHeads: Int + + @ModuleInfo(key: "qkv_proj") var qkvProj: Linear + @ModuleInfo(key: "out_proj") var outProj: Linear + + @ModuleInfo(key: "q_norm") var qNorm: RMSNorm + @ModuleInfo(key: "k_norm") var kNorm: RMSNorm + + let rope: RoPE + + public init(_ args: OpenElmConfiguration, layerId: Int) { + self.args = args + self.headDim = args.headDimensions + let modelDim = args.modelDim + + self.heads = self.args.numQueryHeads[layerId] + self.kvHeads = self.args.kvHeads[layerId] + self.scale = pow(Float(headDim), -0.5) + + let opSize = (heads + (kvHeads * 2)) * headDim + self._qkvProj.wrappedValue = Linear(modelDim, opSize, bias: false) + self._outProj.wrappedValue = Linear(heads * headDim, modelDim, bias: false) + + if args.normalizeQkProjections { + self._qNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps) + self._kNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps) + } + + self.rope = RoPE( + dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil + ) -> (MLXArray, (MLXArray, MLXArray)) { + let (B, L) = (x.dim(0), x.dim(1)) + let qkv = qkvProj(x).reshaped(B, L, heads + (kvHeads * 2), headDim).transposed(0, 2, 1, 3) + + let qkvSplit = split(qkv, indices: [heads, heads + kvHeads], axis: 1) + var queries = qkvSplit[0] + var keys = qkvSplit[1] + var values = qkvSplit[2] + + if args.normalizeQkProjections { + queries = qNorm(queries) + keys = kNorm(keys) + } + + if let (keyCache, valueCache) = cache { + queries = rope(queries, offset: keyCache.dim(2)) + keys = rope(keys, offset: keyCache.dim(2)) + keys = concatenated([keyCache, keys], axis: 2) + values = concatenated([valueCache, values], axis: 2) + } else { + queries = rope(queries) + keys = rope(keys) + } + + let output = MLXFast.scaledDotProductAttention( + queries: queries, keys: keys, values: values, scale: scale, mask: mask + ).transposed(0, 2, 1, 3).reshaped(B, L, heads * headDim) + + return (outProj(output), (keys, values)) + } +} + +private class FeedForwardNetwork: Module, UnaryLayer { + @ModuleInfo var proj_1: Linear + @ModuleInfo var proj_2: Linear + + public init(_ args: OpenElmConfiguration, layedId: Int) { + let dim = args.modelDim + let ffnMultiplier = args.ffnMultipliers[layedId] + let intermediateDim = Int( + makeDivisible(Float(ffnMultiplier) * Float(dim), divisor: args.ffnDimDivisor)) + + self.proj_1 = Linear(dim, 2 * intermediateDim) + self.proj_2 = Linear(intermediateDim, dim) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + let a = proj_1(x) + let b = split(a, parts: 2, axis: -1) + let gate = b[0] + let x = b[1] + return proj_2(silu(gate) * x) + } +} + +private class TransformerDecoderLayer: Module { + @ModuleInfo(key: "attn") var attn: MultiHeadCausalAttention + let ffn: FeedForwardNetwork + + @ModuleInfo(key: "ffn_norm") var ffnNorm: RMSNorm + @ModuleInfo(key: "attn_norm") var attnNorm: RMSNorm + + public init(_ args: OpenElmConfiguration, layerId: Int) { + let dim = args.modelDim + self._attn.wrappedValue = MultiHeadCausalAttention(args, layerId: layerId) + self.ffn = FeedForwardNetwork(args, layedId: layerId) + self._ffnNorm.wrappedValue = RMSNorm(dimensions: dim, eps: args.rmsNormEps) + self._attnNorm.wrappedValue = RMSNorm(dimensions: dim, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil + ) -> (MLXArray, (MLXArray, MLXArray)) { + var (r, cache) = attn(attnNorm(x), mask: mask, cache: cache) + let h = x + r + r = ffn(ffnNorm(h)) + let out = h + r + return (out, cache) + } +} + +class OpenELMModelInner: Module, LLMModel { + var vocabularySize: Int + + @ModuleInfo(key: "token_embeddings") var embedTokens: Embedding + + fileprivate let layers: [TransformerDecoderLayer] + fileprivate let norm: RMSNorm + + public init(_ args: OpenElmConfiguration) { + precondition(args.vocabularySize > 0) + + self.vocabularySize = args.vocabularySize + self._embedTokens.wrappedValue = Embedding( + embeddingCount: self.vocabularySize, dimensions: args.modelDim) + + self.layers = (0 ..< args.numTransformerLayers) + .map { layerId in + TransformerDecoderLayer(args, layerId: layerId) + } + + self.norm = RMSNorm(dimensions: args.modelDim, eps: args.rmsNormEps) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> ( + MLXArray, [(MLXArray, MLXArray)] + ) { + var h = embedTokens(inputs) + var mask: MLXArray? = nil + if h.dim(1) > 1 { + mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1)) + mask = mask?.asType(h.dtype) + } + + var newCache = [(MLXArray, MLXArray)]() + for (i, layer) in layers.enumerated() { + var cacheUpdate: (MLXArray, MLXArray) + (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i]) + newCache.append(cacheUpdate) + } + + return (norm(h), newCache) + } +} + +public class OpenELMModel: Module, LLMModel { + public let vocabularySize: Int + let shareInputOutputLayers: Bool + let transformer: OpenELMModelInner + + @ModuleInfo(key: "lm_head") var lmHead: Linear + + public init(_ args: OpenElmConfiguration) { + self.vocabularySize = args.vocabularySize + self.transformer = OpenELMModelInner(args) + self.shareInputOutputLayers = args.shareInputOutputLayers + self._lmHead.wrappedValue = Linear( + args.numTransformerLayers, args.vocabularySize, bias: false) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> ( + MLXArray, [(MLXArray, MLXArray)] + ) { + var (out, cache) = transformer(inputs, cache: cache) + if shareInputOutputLayers { + out = matmul(out, transformer.embedTokens.weight.T) + } else { + out = lmHead(out) + } + + return (out, cache) + } +} + +public struct OpenElmConfiguration: Codable { + var modelType: String + var headDimensions: Int + var numTransformerLayers: Int + var modelDim: Int + var vocabularySize: Int + var ffnDimDivisor: Int + var numQueryHeads: [Int] = [] + var kvHeads: [Int] = [] + var ffnWithGlu: Bool = true + var normalizeQkProjections: Bool = true + var shareInputOutputLayers: Bool = true + var rmsNormEps: Float = 1e-6 + var ropeTheta: Float = 10_000 + var ropeTraditional: Bool = false + var numGqaGroups: Int = 4 + var ffnMultipliers: [Float] = [0.5, 4.0] + var qkvMultiplier: [Float] = [0.5, 1.0] + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case headDimensions = "head_dim" + case numTransformerLayers = "num_transformer_layers" + case modelDim = "model_dim" + case vocabularySize = "vocab_size" + case ffnDimDivisor = "ffn_dim_divisor" + case ffnMultipliers = "ffn_multipliers" + case ffnWithGlu = "ffn_with_glu" + case normalizeQkProjections = "normalize_qk_projections" + case shareInputOutputLayers = "share_input_output_layers" + } + + public init(from decoder: Decoder) throws { + // custom implementation to handle optional keys with required values + let container: KeyedDecodingContainer = + try decoder.container( + keyedBy: OpenElmConfiguration.CodingKeys.self) + + self.modelType = try container.decode( + String.self, forKey: OpenElmConfiguration.CodingKeys.modelType) + self.headDimensions = try container.decode( + Int.self, forKey: OpenElmConfiguration.CodingKeys.headDimensions) + self.numTransformerLayers = try container.decode( + Int.self, forKey: OpenElmConfiguration.CodingKeys.numTransformerLayers) + + self.modelDim = try container.decode( + Int.self, forKey: OpenElmConfiguration.CodingKeys.modelDim) + self.vocabularySize = try container.decode( + Int.self, forKey: OpenElmConfiguration.CodingKeys.vocabularySize) + self.ffnDimDivisor = try container.decode( + Int.self, forKey: OpenElmConfiguration.CodingKeys.ffnDimDivisor) + + let qkvMultipliers = stride( + from: qkvMultiplier[0], through: qkvMultiplier[1], + by: (qkvMultiplier[1] - qkvMultiplier[0]) / Float(numTransformerLayers - 1) + ) + .map { round($0 * 100) / 100 } + + let headMultipleOf = numGqaGroups + let queryDims = qkvMultipliers.map { a in + makeDivisible(Float(self.modelDim) * a, divisor: self.headDimensions * headMultipleOf) + } + + self.numQueryHeads = queryDims.map { qDim in + Int(computeHeads(modelDim: qDim, headDim: self.headDimensions)) + } + + self.kvHeads = self.numQueryHeads.map { qHeads in + qHeads / numGqaGroups + } + + self.ffnMultipliers = stride( + from: ffnMultipliers[0], through: ffnMultipliers[1], + by: (ffnMultipliers[1] - ffnMultipliers[0]) / Float(numTransformerLayers - 1) + ) + .map { round($0 * 100) / 100 } + + self.ffnWithGlu = + try container.decodeIfPresent( + Bool.self, forKey: OpenElmConfiguration.CodingKeys.ffnWithGlu) ?? true + self.normalizeQkProjections = + try container.decodeIfPresent( + Bool.self, forKey: OpenElmConfiguration.CodingKeys.normalizeQkProjections) ?? true + self.shareInputOutputLayers = + try container.decodeIfPresent( + Bool.self, forKey: OpenElmConfiguration.CodingKeys.shareInputOutputLayers) ?? true + } +} + +// MARK: - LoRA + +extension OpenELMModel: LoRAModel { + public func loraLinearLayers() -> LoRALinearLayers { + transformer.layers.map { ($0.attn, ["qkv_proj"]) } + } +} diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj index 17441d8..c4e7142 100644 --- a/mlx-swift-examples.xcodeproj/project.pbxproj +++ b/mlx-swift-examples.xcodeproj/project.pbxproj @@ -11,6 +11,7 @@ 1CD79C702BD80DE100B6C06F /* Phi3.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1CD79C6F2BD80DE100B6C06F /* Phi3.swift */; }; 525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; }; 52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; }; + 7BBD0D6E2BE044A10019C5D7 /* OpenELM.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7BBD0D6D2BE044A10019C5D7 /* OpenELM.swift */; }; 81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = 81695B402BA373D300F260D8 /* MarkdownUI */; }; 819BEFF82BAF8B4E0002CCEE /* DeviceStat.swift in Sources */ = {isa = PBXBuildFile; fileRef = 819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */; }; C3056BAE2BCD97B700A31D04 /* LoRATrainingExampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3056BAD2BCD97B700A31D04 /* LoRATrainingExampleApp.swift */; }; @@ -220,6 +221,7 @@ 1CD79C6F2BD80DE100B6C06F /* Phi3.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Phi3.swift; sourceTree = ""; }; 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = ""; }; 52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = ""; }; + 7BBD0D6D2BE044A10019C5D7 /* OpenELM.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpenELM.swift; sourceTree = ""; }; 819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DeviceStat.swift; sourceTree = ""; }; C3056BA12BCD973400A31D04 /* test.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = test.jsonl; sourceTree = ""; }; C3056BA22BCD973400A31D04 /* train.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = train.jsonl; sourceTree = ""; }; @@ -470,6 +472,7 @@ C38935C62B869C7A0037B833 /* LLM */ = { isa = PBXGroup; children = ( + 7BBD0D6D2BE044A10019C5D7 /* OpenELM.swift */, C36BEFAF2BBCBAC2002D4AFE /* Lora.swift */, C36BEFBA2BBF02CC002D4AFE /* Lora+Data.swift */, 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */, @@ -1006,6 +1009,7 @@ 525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */, C36BEFBB2BBF02CC002D4AFE /* Lora+Data.swift in Sources */, C36BEFB02BBCBAC2002D4AFE /* Lora.swift in Sources */, + 7BBD0D6E2BE044A10019C5D7 /* OpenELM.swift in Sources */, C38935DF2B869DD00037B833 /* Phi.swift in Sources */, C38935CE2B869C870037B833 /* Load.swift in Sources */, C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */,