From b951b78eb2357268ed3b52b8cf5de0315b1998c9 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 24 Apr 2024 09:31:01 -0700 Subject: [PATCH] phi3 (#54) * phi3 Co-authored-by: David Koski --- Applications/LLMEval/ContentView.swift | 2 +- Libraries/LLM/Configuration.swift | 5 + Libraries/LLM/Evaluate.swift | 10 +- Libraries/LLM/Models.swift | 8 + Libraries/LLM/Phi3.swift | 257 +++++++++++++++++++ Tools/llm-tool/LLMTool.swift | 4 +- mlx-swift-examples.xcodeproj/project.pbxproj | 6 + 7 files changed, 284 insertions(+), 8 deletions(-) create mode 100644 Libraries/LLM/Phi3.swift diff --git a/Applications/LLMEval/ContentView.swift b/Applications/LLMEval/ContentView.swift index 64e38da..02e4478 100644 --- a/Applications/LLMEval/ContentView.swift +++ b/Applications/LLMEval/ContentView.swift @@ -157,7 +157,7 @@ class LLMEvaluator { /// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on /// more devices - let modelConfiguration = ModelConfiguration.phi4bit + let modelConfiguration = ModelConfiguration.phi34bit /// parameters controlling the output let generateParameters = GenerateParameters(temperature: 0.6) diff --git a/Libraries/LLM/Configuration.swift b/Libraries/LLM/Configuration.swift index f29151f..c4fb906 100644 --- a/Libraries/LLM/Configuration.swift +++ b/Libraries/LLM/Configuration.swift @@ -30,6 +30,7 @@ public enum ModelType: String, Codable { case mistral case llama case phi + case phi3 case gemma case qwen2 case starcoder2 @@ -45,6 +46,10 @@ public enum ModelType: String, Codable { let configuration = try JSONDecoder().decode( PhiConfiguration.self, from: Data(contentsOf: configuration)) return PhiModel(configuration) + case .phi3: + let configuration = try JSONDecoder().decode( + Phi3Configuration.self, from: Data(contentsOf: configuration)) + return Phi3Model(configuration) case .gemma: let configuration = try JSONDecoder().decode( GemmaConfiguration.self, from: Data(contentsOf: configuration)) diff --git a/Libraries/LLM/Evaluate.swift b/Libraries/LLM/Evaluate.swift index 94fdda7..5d870a1 100644 --- a/Libraries/LLM/Evaluate.swift +++ b/Libraries/LLM/Evaluate.swift @@ -60,16 +60,16 @@ public struct GenerateParameters { public var temperature: Float = 0.6 /// top p sampling - public var topP: Float = 0.9 + public var topP: Float = 1.0 /// penalty factor for repeating tokens - public var repetitionPenalty: Float = 1.0 + public var repetitionPenalty: Float? /// number of tokens to consider for repetition penalty public var repetitionContextSize: Int = 20 public init( - temperature: Float = 0.6, topP: Float = 0.9, repetitionPenalty: Float = 1.0, + temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? = nil, repetitionContextSize: Int = 20 ) { self.temperature = temperature @@ -111,11 +111,11 @@ public struct TokenIterator: Sequence, IteratorProtocol { var logits: MLXArray (logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache) logits = logits[0..., -1, 0...] 
- if parameters.repetitionPenalty > 1.0 { + if let repetitionPenalty = parameters.repetitionPenalty { // apply repetition penalty logits = applyRepetitionPenalty( logits: logits, repetitionContext: repetitionContext, - penalty: parameters.repetitionPenalty) + penalty: repetitionPenalty) } y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP) // append the current token to the context and check repetitionPenalty context see if need to remove the first token diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift index 030d015..2fe52d7 100644 --- a/Libraries/LLM/Models.swift +++ b/Libraries/LLM/Models.swift @@ -116,6 +116,13 @@ extension ModelConfiguration { "Instruct: \(prompt)\nOutput: " } + public static let phi34bit = ModelConfiguration( + id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed" + ) { + prompt in + "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n" + } + public static let gemma2bQuantized = ModelConfiguration( id: "mlx-community/quantized-gemma-2b-it", overrideTokenizer: "PreTrainedTokenizer" @@ -146,6 +153,7 @@ extension ModelConfiguration { mistral7B4bit, codeLlama13b4bit, phi4bit, + phi34bit, gemma2bQuantized, qwen205b4bit, ]) diff --git a/Libraries/LLM/Phi3.swift b/Libraries/LLM/Phi3.swift new file mode 100644 index 0000000..d35d4de --- /dev/null +++ b/Libraries/LLM/Phi3.swift @@ -0,0 +1,257 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import MLX +import MLXFast +import MLXNN + +private class Attention: Module { + + let args: Phi3Configuration + let scale: Float + + @ModuleInfo(key: "qkv_proj") var wqkv: Linear + @ModuleInfo(key: "o_proj") var wo: Linear + + let rope: RoPE + + public init(_ args: Phi3Configuration) { + self.args = args + + let dim = args.hiddenSize + let heads = args.attentionHeads + let kvHeads = args.kvHeads + + let headDim = args.hiddenSize / heads + self.scale = pow(Float(headDim), -0.5) + + self._wqkv.wrappedValue = Linear(dim, (heads + 2 * kvHeads) * headDim, bias: false) + self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false) + + let ropeScale: Float + if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"), + let factor = ropeScaling["factor"] + { + switch factor { + case .string: + fatalError("ropeScaling.factor must be a float") + case .float(let v): + ropeScale = 1 / v + } + } else { + ropeScale = 1 + } + + self.rope = RoPE( + dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta, + scale: ropeScale) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? 
= nil + ) -> (MLXArray, (MLXArray, MLXArray)) { + let (B, L) = (x.dim(0), x.dim(1)) + + let qkv = split(wqkv(x), parts: 3, axis: -1) + var queries = qkv[0] + var keys = qkv[1] + var values = qkv[2] + + // prepare the queries, keys and values for the attention computation + queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + + if let (keyCache, valueCache) = cache { + queries = rope(queries, offset: keyCache.dim(2)) + keys = rope(keys, offset: keyCache.dim(2)) + keys = concatenated([keyCache, keys], axis: 2) + values = concatenated([valueCache, values], axis: 2) + } else { + queries = rope(queries) + keys = rope(keys) + } + + let output = MLXFast.scaledDotProductAttention( + queries: queries, keys: keys, values: values, scale: scale, mask: mask + ) + .transposed(0, 2, 1, 3) + .reshaped(B, L, -1) + + return (wo(output), (keys, values)) + } +} + +private class MLP: Module, UnaryLayer { + + @ModuleInfo(key: "gate_up_proj") var gate_up: Linear + @ModuleInfo(key: "down_proj") var down: Linear + + public init(dimensions: Int, hiddenDimensions: Int) { + self._gate_up.wrappedValue = Linear(dimensions, 2 * hiddenDimensions, bias: false) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + let gu = split(gate_up(x), parts: 2, axis: -1) + return down(silu(gu[0]) * gu[1]) + } +} + +private class TransformerBlock: Module { + + @ModuleInfo(key: "self_attn") var attention: Attention + let mlp: MLP + + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm + + public init(_ args: Phi3Configuration) { + self._attention.wrappedValue = Attention(args) + self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) + self._inputLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + self._postAttentionLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil + ) -> (MLXArray, (MLXArray, MLXArray)) { + var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache) + let h = x + r + r = mlp(postAttentionLayerNorm(h)) + let out = h + r + return (out, cache) + } +} + +public class Phi3ModelInner: Module { + + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + + fileprivate let layers: [TransformerBlock] + let norm: RMSNorm + + public init(_ args: Phi3Configuration) { + precondition(args.vocabularySize > 0) + + self._embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + + self.layers = (0 ..< args.hiddenLayers) + .map { _ in + TransformerBlock(args) + } + self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> ( + MLXArray, [(MLXArray, MLXArray)] + ) { + var h = embedTokens(inputs) + + var mask: MLXArray? 
= nil + if h.dim(1) > 1 { + mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1)) + mask = mask?.asType(h.dtype) + } + + var newCache = [(MLXArray, MLXArray)]() + + for (i, layer) in layers.enumerated() { + var cacheUpdate: (MLXArray, MLXArray) + (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i]) + newCache.append(cacheUpdate) + } + + return (norm(h), newCache) + } +} + +public class Phi3Model: Module, LLMModel { + + public let vocabularySize: Int + let model: Phi3ModelInner + + @ModuleInfo(key: "lm_head") var lmHead: Linear + + public init(_ args: Phi3Configuration) { + self.vocabularySize = args.vocabularySize + self.model = Phi3ModelInner(args) + self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> ( + MLXArray, [(MLXArray, MLXArray)] + ) { + let (out, cache) = model(inputs, cache: cache) + return (lmHead(out), cache) + } +} + +public struct Phi3Configuration: Codable { + + var hiddenSize: Int + var hiddenLayers: Int + var intermediateSize: Int + var attentionHeads: Int + var rmsNormEps: Float + var vocabularySize: Int + var kvHeads: Int + var ropeTheta: Float = 10_000 + var ropeTraditional: Bool = false + var ropeScaling: [String: StringOrNumber]? = nil + + enum CodingKeys: String, CodingKey { + case hiddenSize = "hidden_size" + case hiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case attentionHeads = "num_attention_heads" + case rmsNormEps = "rms_norm_eps" + case vocabularySize = "vocab_size" + case kvHeads = "num_key_value_heads" + case ropeTheta = "rope_theta" + case ropeTraditional = "rope_traditional" + case ropeScaling = "rope_scaling" + } + + public init(from decoder: Decoder) throws { + // custom implementation to handle optional keys with required values + let container: KeyedDecodingContainer = + try decoder.container( + keyedBy: Phi3Configuration.CodingKeys.self) + + self.hiddenSize = try container.decode( + Int.self, forKey: Phi3Configuration.CodingKeys.hiddenSize) + self.hiddenLayers = try container.decode( + Int.self, forKey: Phi3Configuration.CodingKeys.hiddenLayers) + self.intermediateSize = try container.decode( + Int.self, forKey: Phi3Configuration.CodingKeys.intermediateSize) + self.attentionHeads = try container.decode( + Int.self, forKey: Phi3Configuration.CodingKeys.attentionHeads) + self.rmsNormEps = try container.decode( + Float.self, forKey: Phi3Configuration.CodingKeys.rmsNormEps) + self.vocabularySize = try container.decode( + Int.self, forKey: Phi3Configuration.CodingKeys.vocabularySize) + self.kvHeads = try container.decode(Int.self, forKey: Phi3Configuration.CodingKeys.kvHeads) + self.ropeTheta = + try container.decodeIfPresent( + Float.self, forKey: Phi3Configuration.CodingKeys.ropeTheta) + ?? 10_000 + self.ropeTraditional = + try container.decodeIfPresent( + Bool.self, forKey: Phi3Configuration.CodingKeys.ropeTraditional) ?? 
false
+        self.ropeScaling = try container.decodeIfPresent(
+            [String: StringOrNumber].self, forKey: Phi3Configuration.CodingKeys.ropeScaling)
+
+    }
+}
+
+// MARK: - LoRA
+
+extension Phi3Model: LoRAModel {
+    public func loraLinearLayers() -> LoRALinearLayers {
+        model.layers.map { ($0.attention, ["qkv_proj"]) }
+    }
+}
diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift
index eaca1e1..7116f32 100644
--- a/Tools/llm-tool/LLMTool.swift
+++ b/Tools/llm-tool/LLMTool.swift
@@ -53,10 +53,10 @@ struct GenerateArguments: ParsableArguments {
     var temperature: Float = 0.6

     @Option(name: .long, help: "The top p sampling")
-    var topP: Float = 0.9
+    var topP: Float = 1.0

     @Option(name: .long, help: "The penalty factor for repeating tokens")
-    var repetitionPenalty: Float = 1.0
+    var repetitionPenalty: Float?

     @Option(name: .long, help: "The number of tokens to consider for repetition penalty")
     var repetitionContextSize: Int = 20
diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj
index 9a28def..17441d8 100644
--- a/mlx-swift-examples.xcodeproj/project.pbxproj
+++ b/mlx-swift-examples.xcodeproj/project.pbxproj
@@ -8,6 +8,7 @@

 /* Begin PBXBuildFile section */
 		12305EAF2B9D864400C92FEE /* PredictionView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 12305EAE2B9D864400C92FEE /* PredictionView.swift */; };
+		1CD79C702BD80DE100B6C06F /* Phi3.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1CD79C6F2BD80DE100B6C06F /* Phi3.swift */; };
 		525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; };
 		52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
 		81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = 81695B402BA373D300F260D8 /* MarkdownUI */; };
@@ -216,6 +217,7 @@

 /* Begin PBXFileReference section */
 		12305EAE2B9D864400C92FEE /* PredictionView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PredictionView.swift; sourceTree = "<group>"; };
+		1CD79C6F2BD80DE100B6C06F /* Phi3.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Phi3.swift; sourceTree = "<group>"; };
 		525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = "<group>"; };
 		52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
 		819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DeviceStat.swift; sourceTree = "<group>"; };
@@ -478,6 +480,7 @@
 				C38935C72B869C7A0037B833 /* LLM.h */,
 				C38935E02B869F420037B833 /* LLMModel.swift */,
 				C38935DE2B869DD00037B833 /* Phi.swift */,
+				1CD79C6F2BD80DE100B6C06F /* Phi3.swift */,
 				C34E48F62B69832600FCB841 /* README.md */,
 				C34E48ED2B696E6500FCB841 /* Load.swift */,
 				C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */,
@@ -999,6 +1002,7 @@
 				F24B083A2BAF1A65008C8D19 /* Cohere.swift in Sources */,
 				C38935E32B86C0FE0037B833 /* Gemma.swift in Sources */,
 				C38935CD2B869C870037B833 /* Configuration.swift in Sources */,
+				1CD79C702BD80DE100B6C06F /* Phi3.swift in Sources */,
 				525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */,
 				C36BEFBB2BBF02CC002D4AFE /* Lora+Data.swift in Sources */,
 		
C36BEFB02BBCBAC2002D4AFE /* Lora.swift in Sources */, @@ -2341,6 +2345,7 @@ CURRENT_PROJECT_VERSION = 1; DEBUG_INFORMATION_FORMAT = dwarf; DEVELOPMENT_ASSET_PATHS = "\"Applications/LLMEval/Preview Content\""; + DEVELOPMENT_TEAM = ""; ENABLE_PREVIEWS = YES; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_TESTABILITY = YES; @@ -2431,6 +2436,7 @@ CURRENT_PROJECT_VERSION = 1; DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; DEVELOPMENT_ASSET_PATHS = "\"Applications/LLMEval/Preview Content\""; + DEVELOPMENT_TEAM = ""; ENABLE_NS_ASSERTIONS = NO; ENABLE_PREVIEWS = YES; ENABLE_STRICT_OBJC_MSGSEND = YES;
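
-- 
Notes: the sketches below are illustrative only and sit after the signature
delimiter, so they are not part of the patch itself. They use types and calls
that appear in the diff above, plus the MLXArray range initializer from
mlx-swift; none of them has been run against the repository.

The new MLP in Phi3.swift fuses the gate and up projections into a single
gate_up_proj matmul, splits the result on the last axis, and recombines the
two halves with SiLU gating before the down projection. A minimal standalone
sketch of that pattern:

    import MLX
    import MLXNN

    // One Linear produces [gate | up] concatenated on the last axis.
    let dims = 8
    let hidden = 16
    let gateUp = Linear(dims, 2 * hidden, bias: false)
    let down = Linear(hidden, dims, bias: false)

    // A (batch, sequence, dims) input, cast to float for the matmuls.
    let x = MLXArray(0 ..< 32, [1, 4, dims]).asType(.float32)
    let gu = split(gateUp(x), parts: 2, axis: -1)  // two (1, 4, hidden) halves
    let y = down(silu(gu[0]) * gu[1])              // silu(gate) * up, then down
    print(y.shape)                                 // [1, 4, 8]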
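The Evaluate.swift and LLMTool.swift changes make the repetition penalty
opt-in: nil (or omitting the llm-tool option) now means "off", replacing the
old 1.0 sentinel, and top-p now defaults to 1.0, i.e. disabled. A sketch of
the changed surface, using only the initializer shown in the diff:

    // Defaults after this patch: temperature 0.6, topP 1.0, no penalty.
    let defaults = GenerateParameters()

    // Opting in: the penalty runs only because it is non-nil; inside
    // TokenIterator this is now an `if let` rather than a `> 1.0` check.
    let penalized = GenerateParameters(
        temperature: 0.6, topP: 0.9,
        repetitionPenalty: 1.1,
        repetitionContextSize: 20)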
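Phi3Configuration gets a hand-written init(from:) so that rope keys missing
from config.json fall back to defaults instead of failing the whole decode. A
sketch of that behavior; the field values are illustrative (roughly
Phi-3-mini sized), not read from a real config.json:

    import Foundation

    let json = """
        {
            "hidden_size": 3072,
            "num_hidden_layers": 32,
            "intermediate_size": 8192,
            "num_attention_heads": 32,
            "rms_norm_eps": 1e-5,
            "vocab_size": 32064,
            "num_key_value_heads": 32
        }
        """.data(using: .utf8)!

    do {
        let config = try JSONDecoder().decode(Phi3Configuration.self, from: json)
        // rope_theta, rope_traditional, and rope_scaling are absent, so the
        // custom decoder supplies 10_000, false, and nil rather than throwing.
        print(config.ropeTheta, config.ropeTraditional)  // 10000.0 false
    } catch {
        print("decode failed: \(error)")
    }

The configuration's properties have internal access, so this sketch compiles
inside the LLM library rather than from a client of it.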
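Phi3ModelInner only builds an attention mask when the input carries more than
one token; during incremental decoding the single query token may attend to
every cached key, so no mask is needed. The mask itself is MLXNN's standard
additive causal mask, the same call used in the diff:

    import MLX
    import MLXNN

    // Additive causal mask for a 4-token prompt: zeros on and below the
    // diagonal, large negative values above it. Added to the attention
    // scores before softmax, it gives future positions ~zero weight.
    let mask = MultiHeadAttention.createAdditiveCausalMask(4).asType(.float32)
    print(mask.shape)  // [4, 4]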