diff --git a/Libraries/LLM/Configuration.swift b/Libraries/LLM/Configuration.swift
index 84ee410..dd0519e 100644
--- a/Libraries/LLM/Configuration.swift
+++ b/Libraries/LLM/Configuration.swift
@@ -32,6 +32,7 @@ public enum ModelType: String, Codable {
     case phi
     case phi3
     case gemma
+    case gemma2
     case qwen2
     case starcoder2
     case cohere
@@ -55,6 +56,10 @@ public enum ModelType: String, Codable {
             let configuration = try JSONDecoder().decode(
                 GemmaConfiguration.self, from: Data(contentsOf: configuration))
             return GemmaModel(configuration)
+        case .gemma2:
+            let configuration = try JSONDecoder().decode(
+                GemmaConfiguration.self, from: Data(contentsOf: configuration))
+            return Gemma2Model(configuration)
         case .qwen2:
             let configuration = try JSONDecoder().decode(
                 Qwen2Configuration.self, from: Data(contentsOf: configuration))
diff --git a/Libraries/LLM/Gemma.swift b/Libraries/LLM/Gemma.swift
index 14a96b1..8d91bb2 100644
--- a/Libraries/LLM/Gemma.swift
+++ b/Libraries/LLM/Gemma.swift
@@ -262,3 +262,111 @@ extension GemmaModel: LoRAModel {
         model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
     }
 }
+
+// Gemma 2
+
+// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py
+
+// Minimal changes from Gemma TransformerBlock
+private class Gemma2TransformerBlock: Module {
+
+    @ModuleInfo(key: "self_attn") var attention: Attention
+    let mlp: MLP
+
+    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
+    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
+    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm
+    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
+
+    public init(_ args: GemmaConfiguration) {
+        self._attention.wrappedValue = Attention(args)
+        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
+        self._inputLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    public func callAsFunction(
+        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
+    ) -> (MLXArray, (MLXArray, MLXArray)) {
+        var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
+        let h = x + postAttentionLayerNorm(r)
+        r = mlp(preFeedforwardLayerNorm(h))
+        let out = h + postFeedforwardLayerNorm(r)
+        return (out, cache)
+    }
+}
+
+// Uses Gemma2TransformerBlock, otherwise same as GemmaModelInner
+public class Gemma2ModelInner: Module {
+
+    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
+
+    fileprivate let layers: [Gemma2TransformerBlock]
+    fileprivate let norm: RMSNorm
+
+    let hiddenScale: Float
+
+    public init(_ args: GemmaConfiguration) {
+        precondition(args.vocabularySize > 0)
+
+        self._embedTokens.wrappedValue = Embedding(
+            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
+
+        self.hiddenScale = pow(Float(args.hiddenSize), 0.5)
+
+        self.layers = (0 ..< args.hiddenLayers)
+            .map { _ in
+                Gemma2TransformerBlock(args)
+            }
+        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
+        MLXArray, [(MLXArray, MLXArray)]
+    ) {
+        var h = embedTokens(inputs)
+        h = h * hiddenScale
+
+        var mask: MLXArray? = nil
+        if h.dim(1) > 1 {
+            mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
+            mask = mask?.asType(h.dtype)
+        }
+
+        var newCache = [(MLXArray, MLXArray)]()
+
+        for (i, layer) in layers.enumerated() {
+            var cacheUpdate: (MLXArray, MLXArray)
+            (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
+            newCache.append(cacheUpdate)
+        }
+
+        return (norm(h), newCache)
+    }
+}
+
+// Uses Gemma2ModelInner, otherwise same as GemmaModel
+public class Gemma2Model: Module, LLMModel {
+
+    public let vocabularySize: Int
+    let model: Gemma2ModelInner
+
+    public init(_ args: GemmaConfiguration) {
+        self.vocabularySize = args.vocabularySize
+        self.model = Gemma2ModelInner(args)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
+        MLXArray, [(MLXArray, MLXArray)]
+    ) {
+        var (out, cache) = model(inputs, cache: cache)
+        out = model.embedTokens.asLinear(out)
+        return (out, cache)
+    }
+}
diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift
index 01ab42a..937ee5b 100644
--- a/Libraries/LLM/Models.swift
+++ b/Libraries/LLM/Models.swift
@@ -157,6 +157,17 @@ extension ModelConfiguration {
         "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
     }
 
+    public static let gemma_2_9b_it_4bit = ModelConfiguration(
+        id: "mlx-community/gemma-2-9b-it-4bit",
+        overrideTokenizer: "PreTrainedTokenizer",
+
+        // https://www.promptingguide.ai/models/gemma
+        defaultPrompt: "What is the difference between lettuce and cabbage?"
+
+    ) { prompt in
+        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
+    }
+
     public static let qwen205b4bit = ModelConfiguration(
         id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
         overrideTokenizer: "PreTrainedTokenizer",
@@ -200,6 +211,7 @@
             phi4bit,
             phi34bit,
             gemma2bQuantized,
+            gemma_2_9b_it_4bit,
             qwen205b4bit,
             openelm270m4bit,
         ])
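
For reference, here is a minimal sketch of how the newly registered gemma_2_9b_it_4bit configuration can be driven end to end. It is not part of the patch: the load(configuration:), prepare(prompt:), GenerateParameters, and generate(...) helpers are recalled from the library's Load.swift / Evaluate.swift rather than verified against this revision, so treat the names and signatures below as assumptions.

import Foundation
import LLM
import Tokenizers

// Rough usage sketch only: helper names and signatures are assumptions based on
// Load.swift / Evaluate.swift and may not match the current sources exactly.
func runGemma2Example() async throws {
    // The configuration registered in Models.swift above.
    let configuration = ModelConfiguration.gemma_2_9b_it_4bit

    // Fetches the 4-bit weights on first use; ModelType.gemma2 in
    // Configuration.swift then instantiates Gemma2Model.
    let (model, tokenizer) = try await load(configuration: configuration)

    // Apply the Gemma chat template registered above, then tokenize.
    let prompt = configuration.prepare(prompt: configuration.defaultPrompt)
    let promptTokens = tokenizer.encode(text: prompt)

    // Generate up to 256 tokens and print the decoded output.
    let result = generate(
        promptTokens: promptTokens,
        parameters: GenerateParameters(temperature: 0.6),
        model: model,
        tokenizer: tokenizer
    ) { tokens in
        tokens.count >= 256 ? .stop : .more
    }
    print(result.output)
}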