* phi3

Co-authored-by: David Koski <dkoski@apple.com>
This commit is contained in:
Awni Hannun
2024-04-24 09:31:01 -07:00
committed by GitHub
parent 6c0b66f90a
commit b951b78eb2
7 changed files with 284 additions and 8 deletions

View File

@@ -157,7 +157,7 @@ class LLMEvaluator {
/// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on /// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on
/// more devices /// more devices
let modelConfiguration = ModelConfiguration.phi4bit let modelConfiguration = ModelConfiguration.phi34bit
/// parameters controlling the output /// parameters controlling the output
let generateParameters = GenerateParameters(temperature: 0.6) let generateParameters = GenerateParameters(temperature: 0.6)

View File

@@ -30,6 +30,7 @@ public enum ModelType: String, Codable {
case mistral case mistral
case llama case llama
case phi case phi
case phi3
case gemma case gemma
case qwen2 case qwen2
case starcoder2 case starcoder2
@@ -45,6 +46,10 @@ public enum ModelType: String, Codable {
let configuration = try JSONDecoder().decode( let configuration = try JSONDecoder().decode(
PhiConfiguration.self, from: Data(contentsOf: configuration)) PhiConfiguration.self, from: Data(contentsOf: configuration))
return PhiModel(configuration) return PhiModel(configuration)
case .phi3:
let configuration = try JSONDecoder().decode(
Phi3Configuration.self, from: Data(contentsOf: configuration))
return Phi3Model(configuration)
case .gemma: case .gemma:
let configuration = try JSONDecoder().decode( let configuration = try JSONDecoder().decode(
GemmaConfiguration.self, from: Data(contentsOf: configuration)) GemmaConfiguration.self, from: Data(contentsOf: configuration))

View File

@@ -60,16 +60,16 @@ public struct GenerateParameters {
public var temperature: Float = 0.6 public var temperature: Float = 0.6
/// top p sampling /// top p sampling
public var topP: Float = 0.9 public var topP: Float = 1.0
/// penalty factor for repeating tokens /// penalty factor for repeating tokens
public var repetitionPenalty: Float = 1.0 public var repetitionPenalty: Float?
/// number of tokens to consider for repetition penalty /// number of tokens to consider for repetition penalty
public var repetitionContextSize: Int = 20 public var repetitionContextSize: Int = 20
public init( public init(
temperature: Float = 0.6, topP: Float = 0.9, repetitionPenalty: Float = 1.0, temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? = nil,
repetitionContextSize: Int = 20 repetitionContextSize: Int = 20
) { ) {
self.temperature = temperature self.temperature = temperature
@@ -111,11 +111,11 @@ public struct TokenIterator: Sequence, IteratorProtocol {
var logits: MLXArray var logits: MLXArray
(logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache) (logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
logits = logits[0..., -1, 0...] logits = logits[0..., -1, 0...]
if parameters.repetitionPenalty > 1.0 { if let repetitionPenalty = parameters.repetitionPenalty {
// apply repetition penalty // apply repetition penalty
logits = applyRepetitionPenalty( logits = applyRepetitionPenalty(
logits: logits, repetitionContext: repetitionContext, logits: logits, repetitionContext: repetitionContext,
penalty: parameters.repetitionPenalty) penalty: repetitionPenalty)
} }
y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP) y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
// append the current token to the context and check repetitionPenalty context see if need to remove the first token // append the current token to the context and check repetitionPenalty context see if need to remove the first token

View File

@@ -116,6 +116,13 @@ extension ModelConfiguration {
"Instruct: \(prompt)\nOutput: " "Instruct: \(prompt)\nOutput: "
} }
public static let phi34bit = ModelConfiguration(
id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed"
) {
prompt in
"<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
}
public static let gemma2bQuantized = ModelConfiguration( public static let gemma2bQuantized = ModelConfiguration(
id: "mlx-community/quantized-gemma-2b-it", id: "mlx-community/quantized-gemma-2b-it",
overrideTokenizer: "PreTrainedTokenizer" overrideTokenizer: "PreTrainedTokenizer"
@@ -146,6 +153,7 @@ extension ModelConfiguration {
mistral7B4bit, mistral7B4bit,
codeLlama13b4bit, codeLlama13b4bit,
phi4bit, phi4bit,
phi34bit,
gemma2bQuantized, gemma2bQuantized,
qwen205b4bit, qwen205b4bit,
]) ])

257
Libraries/LLM/Phi3.swift Normal file
View File

@@ -0,0 +1,257 @@
// Copyright © 2024 Apple Inc.
import Foundation
import MLX
import MLXFast
import MLXNN
private class Attention: Module {
let args: Phi3Configuration
let scale: Float
@ModuleInfo(key: "qkv_proj") var wqkv: Linear
@ModuleInfo(key: "o_proj") var wo: Linear
let rope: RoPE
public init(_ args: Phi3Configuration) {
self.args = args
let dim = args.hiddenSize
let heads = args.attentionHeads
let kvHeads = args.kvHeads
let headDim = args.hiddenSize / heads
self.scale = pow(Float(headDim), -0.5)
self._wqkv.wrappedValue = Linear(dim, (heads + 2 * kvHeads) * headDim, bias: false)
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
let ropeScale: Float
if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
let factor = ropeScaling["factor"]
{
switch factor {
case .string:
fatalError("ropeScaling.factor must be a float")
case .float(let v):
ropeScale = 1 / v
}
} else {
ropeScale = 1
}
self.rope = RoPE(
dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
scale: ropeScale)
}
public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
let (B, L) = (x.dim(0), x.dim(1))
let qkv = split(wqkv(x), parts: 3, axis: -1)
var queries = qkv[0]
var keys = qkv[1]
var values = qkv[2]
// prepare the queries, keys and values for the attention computation
queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
if let (keyCache, valueCache) = cache {
queries = rope(queries, offset: keyCache.dim(2))
keys = rope(keys, offset: keyCache.dim(2))
keys = concatenated([keyCache, keys], axis: 2)
values = concatenated([valueCache, values], axis: 2)
} else {
queries = rope(queries)
keys = rope(keys)
}
let output = MLXFast.scaledDotProductAttention(
queries: queries, keys: keys, values: values, scale: scale, mask: mask
)
.transposed(0, 2, 1, 3)
.reshaped(B, L, -1)
return (wo(output), (keys, values))
}
}
private class MLP: Module, UnaryLayer {
@ModuleInfo(key: "gate_up_proj") var gate_up: Linear
@ModuleInfo(key: "down_proj") var down: Linear
public init(dimensions: Int, hiddenDimensions: Int) {
self._gate_up.wrappedValue = Linear(dimensions, 2 * hiddenDimensions, bias: false)
self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
}
public func callAsFunction(_ x: MLXArray) -> MLXArray {
let gu = split(gate_up(x), parts: 2, axis: -1)
return down(silu(gu[0]) * gu[1])
}
}
private class TransformerBlock: Module {
@ModuleInfo(key: "self_attn") var attention: Attention
let mlp: MLP
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
public init(_ args: Phi3Configuration) {
self._attention.wrappedValue = Attention(args)
self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
self._inputLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
self._postAttentionLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
}
public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
let h = x + r
r = mlp(postAttentionLayerNorm(h))
let out = h + r
return (out, cache)
}
}
public class Phi3ModelInner: Module {
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
fileprivate let layers: [TransformerBlock]
let norm: RMSNorm
public init(_ args: Phi3Configuration) {
precondition(args.vocabularySize > 0)
self._embedTokens.wrappedValue = Embedding(
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
self.layers = (0 ..< args.hiddenLayers)
.map { _ in
TransformerBlock(args)
}
self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
}
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
var h = embedTokens(inputs)
var mask: MLXArray? = nil
if h.dim(1) > 1 {
mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
mask = mask?.asType(h.dtype)
}
var newCache = [(MLXArray, MLXArray)]()
for (i, layer) in layers.enumerated() {
var cacheUpdate: (MLXArray, MLXArray)
(h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
newCache.append(cacheUpdate)
}
return (norm(h), newCache)
}
}
public class Phi3Model: Module, LLMModel {
public let vocabularySize: Int
let model: Phi3ModelInner
@ModuleInfo(key: "lm_head") var lmHead: Linear
public init(_ args: Phi3Configuration) {
self.vocabularySize = args.vocabularySize
self.model = Phi3ModelInner(args)
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
}
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
let (out, cache) = model(inputs, cache: cache)
return (lmHead(out), cache)
}
}
public struct Phi3Configuration: Codable {
var hiddenSize: Int
var hiddenLayers: Int
var intermediateSize: Int
var attentionHeads: Int
var rmsNormEps: Float
var vocabularySize: Int
var kvHeads: Int
var ropeTheta: Float = 10_000
var ropeTraditional: Bool = false
var ropeScaling: [String: StringOrNumber]? = nil
enum CodingKeys: String, CodingKey {
case hiddenSize = "hidden_size"
case hiddenLayers = "num_hidden_layers"
case intermediateSize = "intermediate_size"
case attentionHeads = "num_attention_heads"
case rmsNormEps = "rms_norm_eps"
case vocabularySize = "vocab_size"
case kvHeads = "num_key_value_heads"
case ropeTheta = "rope_theta"
case ropeTraditional = "rope_traditional"
case ropeScaling = "rope_scaling"
}
public init(from decoder: Decoder) throws {
// custom implementation to handle optional keys with required values
let container: KeyedDecodingContainer<Phi3Configuration.CodingKeys> =
try decoder.container(
keyedBy: Phi3Configuration.CodingKeys.self)
self.hiddenSize = try container.decode(
Int.self, forKey: Phi3Configuration.CodingKeys.hiddenSize)
self.hiddenLayers = try container.decode(
Int.self, forKey: Phi3Configuration.CodingKeys.hiddenLayers)
self.intermediateSize = try container.decode(
Int.self, forKey: Phi3Configuration.CodingKeys.intermediateSize)
self.attentionHeads = try container.decode(
Int.self, forKey: Phi3Configuration.CodingKeys.attentionHeads)
self.rmsNormEps = try container.decode(
Float.self, forKey: Phi3Configuration.CodingKeys.rmsNormEps)
self.vocabularySize = try container.decode(
Int.self, forKey: Phi3Configuration.CodingKeys.vocabularySize)
self.kvHeads = try container.decode(Int.self, forKey: Phi3Configuration.CodingKeys.kvHeads)
self.ropeTheta =
try container.decodeIfPresent(
Float.self, forKey: Phi3Configuration.CodingKeys.ropeTheta)
?? 10_000
self.ropeTraditional =
try container.decodeIfPresent(
Bool.self, forKey: Phi3Configuration.CodingKeys.ropeTraditional) ?? false
self.ropeScaling = try container.decodeIfPresent(
[String: StringOrNumber].self, forKey: Phi3Configuration.CodingKeys.ropeScaling)
}
}
// MARK: - LoRA
extension Phi3Model: LoRAModel {
public func loraLinearLayers() -> LoRALinearLayers {
model.layers.map { ($0.attention, ["qkv_proj"]) }
}
}

View File

@@ -53,10 +53,10 @@ struct GenerateArguments: ParsableArguments {
var temperature: Float = 0.6 var temperature: Float = 0.6
@Option(name: .long, help: "The top p sampling") @Option(name: .long, help: "The top p sampling")
var topP: Float = 0.9 var topP: Float = 1.0
@Option(name: .long, help: "The penalty factor for repeating tokens") @Option(name: .long, help: "The penalty factor for repeating tokens")
var repetitionPenalty: Float = 1.0 var repetitionPenalty: Float?
@Option(name: .long, help: "The number of tokens to consider for repetition penalty") @Option(name: .long, help: "The number of tokens to consider for repetition penalty")
var repetitionContextSize: Int = 20 var repetitionContextSize: Int = 20

View File

@@ -8,6 +8,7 @@
/* Begin PBXBuildFile section */ /* Begin PBXBuildFile section */
12305EAF2B9D864400C92FEE /* PredictionView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 12305EAE2B9D864400C92FEE /* PredictionView.swift */; }; 12305EAF2B9D864400C92FEE /* PredictionView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 12305EAE2B9D864400C92FEE /* PredictionView.swift */; };
1CD79C702BD80DE100B6C06F /* Phi3.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1CD79C6F2BD80DE100B6C06F /* Phi3.swift */; };
525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; }; 525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; };
52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; }; 52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = 81695B402BA373D300F260D8 /* MarkdownUI */; }; 81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = 81695B402BA373D300F260D8 /* MarkdownUI */; };
@@ -216,6 +217,7 @@
/* Begin PBXFileReference section */ /* Begin PBXFileReference section */
12305EAE2B9D864400C92FEE /* PredictionView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PredictionView.swift; sourceTree = "<group>"; }; 12305EAE2B9D864400C92FEE /* PredictionView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PredictionView.swift; sourceTree = "<group>"; };
1CD79C6F2BD80DE100B6C06F /* Phi3.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Phi3.swift; sourceTree = "<group>"; };
525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = "<group>"; }; 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = "<group>"; };
52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; }; 52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DeviceStat.swift; sourceTree = "<group>"; }; 819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DeviceStat.swift; sourceTree = "<group>"; };
@@ -478,6 +480,7 @@
C38935C72B869C7A0037B833 /* LLM.h */, C38935C72B869C7A0037B833 /* LLM.h */,
C38935E02B869F420037B833 /* LLMModel.swift */, C38935E02B869F420037B833 /* LLMModel.swift */,
C38935DE2B869DD00037B833 /* Phi.swift */, C38935DE2B869DD00037B833 /* Phi.swift */,
1CD79C6F2BD80DE100B6C06F /* Phi3.swift */,
C34E48F62B69832600FCB841 /* README.md */, C34E48F62B69832600FCB841 /* README.md */,
C34E48ED2B696E6500FCB841 /* Load.swift */, C34E48ED2B696E6500FCB841 /* Load.swift */,
C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */, C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */,
@@ -999,6 +1002,7 @@
F24B083A2BAF1A65008C8D19 /* Cohere.swift in Sources */, F24B083A2BAF1A65008C8D19 /* Cohere.swift in Sources */,
C38935E32B86C0FE0037B833 /* Gemma.swift in Sources */, C38935E32B86C0FE0037B833 /* Gemma.swift in Sources */,
C38935CD2B869C870037B833 /* Configuration.swift in Sources */, C38935CD2B869C870037B833 /* Configuration.swift in Sources */,
1CD79C702BD80DE100B6C06F /* Phi3.swift in Sources */,
525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */, 525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */,
C36BEFBB2BBF02CC002D4AFE /* Lora+Data.swift in Sources */, C36BEFBB2BBF02CC002D4AFE /* Lora+Data.swift in Sources */,
C36BEFB02BBCBAC2002D4AFE /* Lora.swift in Sources */, C36BEFB02BBCBAC2002D4AFE /* Lora.swift in Sources */,
@@ -2341,6 +2345,7 @@
CURRENT_PROJECT_VERSION = 1; CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = dwarf; DEBUG_INFORMATION_FORMAT = dwarf;
DEVELOPMENT_ASSET_PATHS = "\"Applications/LLMEval/Preview Content\""; DEVELOPMENT_ASSET_PATHS = "\"Applications/LLMEval/Preview Content\"";
DEVELOPMENT_TEAM = "";
ENABLE_PREVIEWS = YES; ENABLE_PREVIEWS = YES;
ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES; ENABLE_TESTABILITY = YES;
@@ -2431,6 +2436,7 @@
CURRENT_PROJECT_VERSION = 1; CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
DEVELOPMENT_ASSET_PATHS = "\"Applications/LLMEval/Preview Content\""; DEVELOPMENT_ASSET_PATHS = "\"Applications/LLMEval/Preview Content\"";
DEVELOPMENT_TEAM = "";
ENABLE_NS_ASSERTIONS = NO; ENABLE_NS_ASSERTIONS = NO;
ENABLE_PREVIEWS = YES; ENABLE_PREVIEWS = YES;
ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_STRICT_OBJC_MSGSEND = YES;