From 66d92023607dd0980da73b2e17e9f9fc2d766e90 Mon Sep 17 00:00:00 2001
From: John Mai
Date: Sun, 3 Mar 2024 22:26:28 +0800
Subject: [PATCH] feat: Qwen2 support

---
 Libraries/LLM/Configuration.swift            |   5 +
 Libraries/LLM/Models.swift                   |   8 +
 Libraries/LLM/Qwen2.swift                    | 263 ++++++++++++++++++
 Libraries/LLM/Tokenizer.swift                |   1 +
 Tools/llm-tool/LLMTool.swift                 |   6 +-
 mlx-swift-examples.xcodeproj/project.pbxproj |   4 +
 .../xcshareddata/xcschemes/llm-tool.xcscheme |   6 +-
 7 files changed, 291 insertions(+), 2 deletions(-)
 create mode 100644 Libraries/LLM/Qwen2.swift

diff --git a/Libraries/LLM/Configuration.swift b/Libraries/LLM/Configuration.swift
index bae4f3f..2a08aa9 100644
--- a/Libraries/LLM/Configuration.swift
+++ b/Libraries/LLM/Configuration.swift
@@ -31,6 +31,7 @@ public enum ModelType: String, Codable {
     case llama
     case phi
     case gemma
+    case qwen2
 
     func createModel(configuration: URL) throws -> LLMModel {
         switch self {
@@ -46,6 +47,10 @@ public enum ModelType: String, Codable {
             let configuration = try JSONDecoder().decode(
                 GemmaConfiguration.self, from: Data(contentsOf: configuration))
             return GemmaModel(configuration)
+        case .qwen2:
+            let configuration = try JSONDecoder().decode(
+                Qwen2Configuration.self, from: Data(contentsOf: configuration))
+            return Qwen2Model(configuration)
         }
     }
 }
diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift
index 5564c4a..309cf85 100644
--- a/Libraries/LLM/Models.swift
+++ b/Libraries/LLM/Models.swift
@@ -84,6 +84,13 @@ extension ModelConfiguration {
     ) { prompt in
         "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
     }
+
+    public static let qwen205b4bit = ModelConfiguration(
+        id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
+        overrideTokenizer: "PreTrainedTokenizer"
+    ) { prompt in
+        "<|im_start|>user \(prompt)<|im_end|><|im_start|>assistant"
+    }
 
     private enum BootstrapState {
         case idle
@@ -102,6 +109,7 @@ extension ModelConfiguration {
             codeLlama13b4bit,
             phi4bit,
             gemma2bQuantized,
+            qwen205b4bit,
         ])
 
         bootstrapState = .bootstrapped
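
The qwen205b4bit configuration registered above wraps the raw prompt in Qwen's ChatML chat template and leaves the assistant turn open, so generation continues from there. A minimal sketch of what the closure produces, using a made-up prompt:

    let template: (String) -> String = { prompt in
        "<|im_start|>user \(prompt)<|im_end|><|im_start|>assistant"
    }

    // illustrative prompt, not part of the patch
    print(template("Why is the sky blue?"))
    // <|im_start|>user Why is the sky blue?<|im_end|><|im_start|>assistant
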
diff --git a/Libraries/LLM/Qwen2.swift b/Libraries/LLM/Qwen2.swift
new file mode 100644
index 0000000..82bfb83
--- /dev/null
+++ b/Libraries/LLM/Qwen2.swift
@@ -0,0 +1,263 @@
+//
+//  Qwen2.swift
+//  LLM
+//
+//  Created by John Mai on 2024/3/3.
+//
+
+import Foundation
+import MLX
+import MLXNN
+
+// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/qwen2.py
+
+private class Attention: Module {
+    let args: Qwen2Configuration
+    let repeats: Int
+    let scale: Float
+
+    @ModuleInfo(key: "q_proj") var wq: Linear
+    @ModuleInfo(key: "k_proj") var wk: Linear
+    @ModuleInfo(key: "v_proj") var wv: Linear
+    @ModuleInfo(key: "o_proj") var wo: Linear
+
+    let rope: RoPE
+
+    public init(_ args: Qwen2Configuration) {
+        self.args = args
+
+        let dim = args.hiddenSize
+        let heads = args.attentionHeads
+        let kvHeads = args.kvHeads
+
+        self.repeats = heads / kvHeads
+
+        let headDim = args.hiddenSize / heads
+        self.scale = pow(Float(headDim), -0.5)
+
+        _wq.wrappedValue = Linear(dim, heads * headDim, bias: true)
+        _wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
+        _wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
+        _wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
+
+        let ropeScale: Float
+        if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
+            let factor = ropeScaling["factor"]
+        {
+            switch factor {
+            case .string:
+                fatalError("ropeScaling.factor must be a float")
+            case .float(let v):
+                ropeScale = 1 / v
+            }
+        } else {
+            ropeScale = 1
+        }
+
+        self.rope = RoPE(
+            dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
+            scale: ropeScale)
+    }
+
+    public func callAsFunction(
+        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
+    ) -> (MLXArray, (MLXArray, MLXArray)) {
+        let (B, L) = (x.dim(0), x.dim(1))
+
+        var queries = wq(x)
+        var keys = wk(x)
+        var values = wv(x)
+
+        // prepare the queries, keys and values for the attention computation
+        queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
+        keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+
+        if repeats > 1 {
+            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
+            values = MLXArray.repeat(values, count: repeats, axis: 1)
+        }
+
+        if let (keyCache, valueCache) = cache {
+            queries = rope(queries, offset: keyCache.dim(2))
+            keys = rope(keys, offset: keyCache.dim(2))
+            keys = concatenated([keyCache, keys], axis: 2)
+            values = concatenated([valueCache, values], axis: 2)
+        } else {
+            queries = rope(queries)
+            keys = rope(keys)
+        }
+
+        var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
+        if let mask {
+            scores = scores + mask
+        }
+
+        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
+
+        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+
+        return (wo(output), (keys, values))
+    }
+}
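
Qwen2 checkpoints can use fewer key/value heads than query heads (grouped-query attention), so the Attention class above tiles each K/V head `repeats` times to make the score matmul line up. A standalone shape sketch with made-up dimensions, not the real Qwen1.5 sizes:

    import MLX

    let (B, L, headDim) = (1, 8, 64)
    let (heads, kvHeads) = (16, 4)  // hypothetical head counts
    let repeats = heads / kvHeads   // 4

    // keys as laid out after the reshape/transpose: [B, kvHeads, L, headDim]
    let keys = MLXArray.zeros([B, kvHeads, L, headDim])

    // tile the head axis so keys line up with the 16 query heads
    let expanded = MLXArray.repeat(keys, count: repeats, axis: 1)
    print(expanded.shape)  // [1, 16, 8, 64]
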
+
+private class MLP: Module, UnaryLayer {
+    @ModuleInfo(key: "gate_proj") var gate: Linear
+    @ModuleInfo(key: "down_proj") var down: Linear
+    @ModuleInfo(key: "up_proj") var up: Linear
+
+    public init(dimensions: Int, hiddenDimensions: Int) {
+        _gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
+        _down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
+        _up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
+    }
+
+    public func callAsFunction(_ x: MLXArray) -> MLXArray {
+        down(silu(gate(x)) * up(x))
+    }
+}
+
+private class TransformerBlock: Module {
+    @ModuleInfo(key: "self_attn") var attention: Attention
+    let mlp: MLP
+
+    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
+    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
+
+    public init(_ args: Qwen2Configuration) {
+        _attention.wrappedValue = Attention(args)
+        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
+        _inputLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        _postAttentionLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    public func callAsFunction(
+        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
+    ) -> (MLXArray, (MLXArray, MLXArray)) {
+        var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
+        let h = x + r
+        r = mlp(postAttentionLayerNorm(h))
+        let out = h + r
+        return (out, cache)
+    }
+}
+
+public class Qwen2ModelInner: Module {
+    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
+
+    fileprivate let layers: [TransformerBlock]
+    let norm: RMSNorm
+
+    public init(_ args: Qwen2Configuration) {
+        precondition(args.vocabularySize > 0)
+
+        _embedTokens.wrappedValue = Embedding(
+            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
+
+        self.layers = (0 ..< args.hiddenLayers)
+            .map { _ in
+                TransformerBlock(args)
+            }
+        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
+        MLXArray, [(MLXArray, MLXArray)]
+    ) {
+        var h = embedTokens(inputs)
+
+        var mask: MLXArray? = nil
+        if h.dim(1) > 1 {
+            mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
+            mask = mask?.asType(h.dtype)
+        }
+
+        var newCache = [(MLXArray, MLXArray)]()
+
+        for (i, layer) in layers.enumerated() {
+            var cacheUpdate: (MLXArray, MLXArray)
+            (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
+            newCache.append(cacheUpdate)
+        }
+
+        return (norm(h), newCache)
+    }
+}
+
+public class Qwen2Model: Module, LLMModel {
+    public let vocabularySize: Int
+    let model: Qwen2ModelInner
+
+    @ModuleInfo(key: "lm_head") var lmHead: Linear
+
+    public init(_ args: Qwen2Configuration) {
+        self.vocabularySize = args.vocabularySize
+        self.model = Qwen2ModelInner(args)
+        _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
+        MLXArray, [(MLXArray, MLXArray)]
+    ) {
+        let (out, cache) = model(inputs, cache: cache)
+        return (lmHead(out), cache)
+    }
+}
+
+public struct Qwen2Configuration: Codable {
+    var hiddenSize: Int
+    var hiddenLayers: Int
+    var intermediateSize: Int
+    var attentionHeads: Int
+    var rmsNormEps: Float
+    var vocabularySize: Int
+    var kvHeads: Int
+    var ropeTheta: Float = 1_000_000
+    var ropeTraditional: Bool = false
+    var ropeScaling: [String: StringOrNumber]? = nil
+
+    enum CodingKeys: String, CodingKey {
+        case hiddenSize = "hidden_size"
+        case hiddenLayers = "num_hidden_layers"
+        case intermediateSize = "intermediate_size"
+        case attentionHeads = "num_attention_heads"
+        case rmsNormEps = "rms_norm_eps"
+        case vocabularySize = "vocab_size"
+        case kvHeads = "num_key_value_heads"
+        case ropeTheta = "rope_theta"
+        case ropeTraditional = "rope_traditional"
+        case ropeScaling = "rope_scaling"
+    }
+
+    public init(from decoder: Decoder) throws {
+        // custom implementation to handle optional keys with required values
+        let container: KeyedDecodingContainer<Qwen2Configuration.CodingKeys> =
+            try decoder.container(
+                keyedBy: Qwen2Configuration.CodingKeys.self)
+
+        self.hiddenSize = try container.decode(
+            Int.self, forKey: Qwen2Configuration.CodingKeys.hiddenSize)
+        self.hiddenLayers = try container.decode(
+            Int.self, forKey: Qwen2Configuration.CodingKeys.hiddenLayers)
+        self.intermediateSize = try container.decode(
+            Int.self, forKey: Qwen2Configuration.CodingKeys.intermediateSize)
+        self.attentionHeads = try container.decode(
+            Int.self, forKey: Qwen2Configuration.CodingKeys.attentionHeads)
+        self.rmsNormEps = try container.decode(
+            Float.self, forKey: Qwen2Configuration.CodingKeys.rmsNormEps)
+        self.vocabularySize = try container.decode(
+            Int.self, forKey: Qwen2Configuration.CodingKeys.vocabularySize)
+        self.kvHeads = try container.decode(Int.self, forKey: Qwen2Configuration.CodingKeys.kvHeads)
+        self.ropeTheta =
+            try container.decodeIfPresent(
+                Float.self, forKey: Qwen2Configuration.CodingKeys.ropeTheta)
+            ?? 1_000_000
+        self.ropeTraditional =
+            try container.decodeIfPresent(
+                Bool.self, forKey: Qwen2Configuration.CodingKeys.ropeTraditional) ?? false
+        self.ropeScaling = try container.decodeIfPresent(
+            [String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
+    }
+}
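
The hand-written init(from:) above exists so that rope_theta, rope_traditional and rope_scaling may be omitted from a checkpoint's config.json while the other keys stay required. A quick sketch of the fallback behaviour, from inside the LLM module (the properties are internal) and in a throwing context; the hyperparameter values are illustrative, not taken from a real checkpoint:

    import Foundation

    let json = """
        {
            "hidden_size": 1024,
            "num_hidden_layers": 24,
            "intermediate_size": 2816,
            "num_attention_heads": 16,
            "rms_norm_eps": 1e-6,
            "vocab_size": 151936,
            "num_key_value_heads": 16
        }
        """.data(using: .utf8)!

    let config = try JSONDecoder().decode(Qwen2Configuration.self, from: json)
    print(config.ropeTheta)        // 1000000.0 (default, rope_theta absent)
    print(config.ropeTraditional)  // false (default, rope_traditional absent)
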
diff --git a/Libraries/LLM/Tokenizer.swift b/Libraries/LLM/Tokenizer.swift
index ce8b291..0159bec 100644
--- a/Libraries/LLM/Tokenizer.swift
+++ b/Libraries/LLM/Tokenizer.swift
@@ -116,4 +116,5 @@ public func discardUnhandledMerges(tokenizerData: Config) -> Config {
 let replacementTokenizers = [
     "CodeLlamaTokenizer": "LlamaTokenizer",
     "GemmaTokenizer": "PreTrainedTokenizer",
+    "Qwen2Tokenizer": "PreTrainedTokenizer",
 ]
diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift
index 4406025..e0035e2 100644
--- a/Tools/llm-tool/LLMTool.swift
+++ b/Tools/llm-tool/LLMTool.swift
@@ -42,7 +42,9 @@ struct SyncGenerator: AsyncParsableCommand {
 
         let modelConfiguration = ModelConfiguration.configuration(id: model)
         let (model, tokenizer) = try await load(configuration: modelConfiguration)
-
+
+        print("Model loaded -> \(self.model)")
+
         let prompt = modelConfiguration.prepare(prompt: self.prompt)
         let promptTokens = tokenizer.encode(text: prompt)
 
@@ -131,6 +133,8 @@ struct AsyncGenerator: AsyncParsableCommand {
 
         let modelConfiguration = ModelConfiguration.configuration(id: model)
         let (model, tokenizer) = try await load(configuration: modelConfiguration)
+
+        print("Model loaded -> \(self.model)")
 
         let prompt = modelConfiguration.prepare(prompt: self.prompt)
         let promptTokens = tokenizer.encode(text: prompt)
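
The calls the tool makes can also be exercised directly from library code. A sketch assuming an async throwing context and the qwen205b4bit registration from Models.swift above; "Hello!" is an illustrative prompt:

    import LLM

    let configuration = ModelConfiguration.configuration(
        id: "mlx-community/Qwen1.5-0.5B-Chat-4bit")
    let (model, tokenizer) = try await load(configuration: configuration)

    let prompt = configuration.prepare(prompt: "Hello!")
    let promptTokens = tokenizer.encode(text: prompt)
    print("Model loaded, \(promptTokens.count) prompt tokens")
    _ = model  // ready to hand to the generation loop
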
diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj
index 3ee0392..0623fb2 100644
--- a/mlx-swift-examples.xcodeproj/project.pbxproj
+++ b/mlx-swift-examples.xcodeproj/project.pbxproj
@@ -7,6 +7,7 @@
 	objects = {
 
 /* Begin PBXBuildFile section */
+		52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
 		C3288D762B6D9313009FF608 /* LinearModelTraining.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3288D752B6D9313009FF608 /* LinearModelTraining.swift */; };
 		C3288D7B2B6D9339009FF608 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C3288D7A2B6D9339009FF608 /* ArgumentParser */; };
 		C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48F42B696F0B00FCB841 /* LLMTool.swift */; };
@@ -180,6 +181,7 @@
 /* End PBXCopyFilesBuildPhase section */
 
 /* Begin PBXFileReference section */
+		52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
 		C325DE3F2B648CDB00628871 /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
 		C3288D732B6D9313009FF608 /* LinearModelTraining */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = LinearModelTraining; sourceTree = BUILT_PRODUCTS_DIR; };
 		C3288D752B6D9313009FF608 /* LinearModelTraining.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LinearModelTraining.swift; sourceTree = "<group>"; };
@@ -363,6 +365,7 @@
 				C34E48ED2B696E6500FCB841 /* Load.swift */,
 				C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */,
 				C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */,
+				52A776172B94B5EE00AA6E80 /* Qwen2.swift */,
 			);
 			path = LLM;
 			sourceTree = "<group>";
@@ -829,6 +832,7 @@
 				C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */,
 				C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */,
 				C38935CC2B869C870037B833 /* Llama.swift in Sources */,
+				52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
diff --git a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
index 0200f91..b472939 100644
--- a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
+++ b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
@@ -55,6 +55,10 @@
                argument = "--model mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX"
                isEnabled = "NO">
             </CommandLineArgument>
+            <CommandLineArgument
+               argument = "--model mlx-community/Qwen1.5-0.5B-Chat-4bit"
+               isEnabled = "YES">
+            </CommandLineArgument>
             <CommandLineArgument
@@ -69,7 +73,7 @@
-               isEnabled = "YES">
+               isEnabled = "NO">
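
With the scheme argument in place, the new model can be run from the llm-tool scheme in Xcode, or from a terminal along these lines (assuming the tool's existing --model and --prompt options; the prompt text is illustrative):

    llm-tool --model mlx-community/Qwen1.5-0.5B-Chat-4bit --prompt "Why is the sky blue?"
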