@@ -30,6 +30,7 @@ public enum ModelType: String, Codable {
     case mistral
     case llama
     case phi
+    case phi3
     case gemma
     case qwen2
     case starcoder2
@@ -45,6 +46,10 @@ public enum ModelType: String, Codable {
             let configuration = try JSONDecoder().decode(
                 PhiConfiguration.self, from: Data(contentsOf: configuration))
             return PhiModel(configuration)
+        case .phi3:
+            let configuration = try JSONDecoder().decode(
+                Phi3Configuration.self, from: Data(contentsOf: configuration))
+            return Phi3Model(configuration)
         case .gemma:
             let configuration = try JSONDecoder().decode(
                 GemmaConfiguration.self, from: Data(contentsOf: configuration))
@@ -60,16 +60,16 @@ public struct GenerateParameters {
     public var temperature: Float = 0.6

     /// top p sampling
-    public var topP: Float = 0.9
+    public var topP: Float = 1.0

     /// penalty factor for repeating tokens
-    public var repetitionPenalty: Float = 1.0
+    public var repetitionPenalty: Float?

     /// number of tokens to consider for repetition penalty
     public var repetitionContextSize: Int = 20

     public init(
-        temperature: Float = 0.6, topP: Float = 0.9, repetitionPenalty: Float = 1.0,
+        temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? = nil,
         repetitionContextSize: Int = 20
     ) {
         self.temperature = temperature
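With this change repetitionPenalty becomes opt-in: nil (the new default) skips the penalty branch in TokenIterator below. A minimal usage sketch; the numeric values are illustrative, not recommendations:

    // defaults now mean: temperature 0.6, topP 1.0, no repetition penalty
    let defaults = GenerateParameters()

    // opting in; 1.3 and 64 are illustrative values
    let penalized = GenerateParameters(
        temperature: 0.7, topP: 0.9, repetitionPenalty: 1.3,
        repetitionContextSize: 64)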
@@ -111,11 +111,11 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         var logits: MLXArray
         (logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
         logits = logits[0..., -1, 0...]
-        if parameters.repetitionPenalty > 1.0 {
+        if let repetitionPenalty = parameters.repetitionPenalty {
             // apply repetition penalty
             logits = applyRepetitionPenalty(
                 logits: logits, repetitionContext: repetitionContext,
-                penalty: parameters.repetitionPenalty)
+                penalty: repetitionPenalty)
         }
         y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
         // append the current token to the repetition context and drop its first token once repetitionContextSize is exceeded
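applyRepetitionPenalty itself is defined outside this hunk. As a hedged sketch (not necessarily the repository's exact implementation), the usual CTRL-style rule divides positive logits by the penalty and multiplies negative ones, so every recently seen token becomes less likely:

    // sketch only: repetitionContext is assumed to be a 1-D MLXArray of recent token ids
    func applyRepetitionPenalty(logits: MLXArray, repetitionContext: MLXArray, penalty: Float)
        -> MLXArray
    {
        guard repetitionContext.shape[0] > 0 else { return logits }
        let selected = logits[0..., repetitionContext]
        // push logits of both signs away from reselection
        logits[0..., repetitionContext] = which(
            selected .< 0, selected * penalty, selected / penalty)
        return logits
    }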
@@ -116,6 +116,13 @@ extension ModelConfiguration {
         "Instruct: \(prompt)\nOutput: "
     }

+    public static let phi34bit = ModelConfiguration(
+        id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed"
+    ) {
+        prompt in
+        "<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
+    }
+
     public static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
         overrideTokenizer: "PreTrainedTokenizer"
@@ -146,6 +153,7 @@ extension ModelConfiguration {
             mistral7B4bit,
             codeLlama13b4bit,
             phi4bit,
+            phi34bit,
             gemma2bQuantized,
             qwen205b4bit,
         ])
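The new entry follows the established ModelConfiguration pattern: a Hugging Face Hub id plus a trailing closure that wraps the raw prompt in the model's chat template. A sketch of registering a further variant the same way (the id below is hypothetical, shown only to illustrate the pattern):

    extension ModelConfiguration {
        // hypothetical Hub id, not a real repository
        public static let myPhi3Variant = ModelConfiguration(
            id: "my-org/Phi-3-mini-4k-instruct-8bit"
        ) {
            prompt in
            "<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
        }
    }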
Libraries/LLM/Phi3.swift (new file, 257 lines)
@@ -0,0 +1,257 @@
// Copyright © 2024 Apple Inc.

import Foundation
import MLX
import MLXFast
import MLXNN

private class Attention: Module {

    let args: Phi3Configuration
    let scale: Float

    @ModuleInfo(key: "qkv_proj") var wqkv: Linear
    @ModuleInfo(key: "o_proj") var wo: Linear

    let rope: RoPE

    public init(_ args: Phi3Configuration) {
        self.args = args

        let dim = args.hiddenSize
        let heads = args.attentionHeads
        let kvHeads = args.kvHeads

        let headDim = args.hiddenSize / heads
        self.scale = pow(Float(headDim), -0.5)

        self._wqkv.wrappedValue = Linear(dim, (heads + 2 * kvHeads) * headDim, bias: false)
        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)

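        // "linear" RoPE scaling applies positions * (1 / factor), stretching
        // the context window the rotary embedding can address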
        let ropeScale: Float
        if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
            let factor = ropeScaling["factor"]
        {
            switch factor {
            case .string:
                fatalError("ropeScaling.factor must be a float")
            case .float(let v):
                ropeScale = 1 / v
            }
        } else {
            ropeScale = 1
        }

        self.rope = RoPE(
            dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
            scale: ropeScale)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
    ) -> (MLXArray, (MLXArray, MLXArray)) {
        let (B, L) = (x.dim(0), x.dim(1))

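        // a single fused matmul yields Q, K and V; splitting into three equal
        // parts assumes attentionHeads == kvHeads, which holds for Phi-3 mini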
        let qkv = split(wqkv(x), parts: 3, axis: -1)
        var queries = qkv[0]
        var keys = qkv[1]
        var values = qkv[2]

        // prepare the queries, keys and values for the attention computation
        queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
        keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

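        // with a KV cache, rotate the new positions starting at the cached
        // length, then append the new keys and values along the time axis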
        if let (keyCache, valueCache) = cache {
            queries = rope(queries, offset: keyCache.dim(2))
            keys = rope(keys, offset: keyCache.dim(2))
            keys = concatenated([keyCache, keys], axis: 2)
            values = concatenated([valueCache, values], axis: 2)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        let output = MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: scale, mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return (wo(output), (keys, values))
    }
}

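// SwiGLU-style feed-forward: gate_up_proj fuses the gate and up projections
// into a single matmul whose output is split in half, then silu(gate) * up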
private class MLP: Module, UnaryLayer {

    @ModuleInfo(key: "gate_up_proj") var gate_up: Linear
    @ModuleInfo(key: "down_proj") var down: Linear

    public init(dimensions: Int, hiddenDimensions: Int) {
        self._gate_up.wrappedValue = Linear(dimensions, 2 * hiddenDimensions, bias: false)
        self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        let gu = split(gate_up(x), parts: 2, axis: -1)
        return down(silu(gu[0]) * gu[1])
    }
}

private class TransformerBlock: Module {

    @ModuleInfo(key: "self_attn") var attention: Attention
    let mlp: MLP

    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm

    public init(_ args: Phi3Configuration) {
        self._attention.wrappedValue = Attention(args)
        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
        self._inputLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
    ) -> (MLXArray, (MLXArray, MLXArray)) {
        var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
        let h = x + r
        r = mlp(postAttentionLayerNorm(h))
        let out = h + r
        return (out, cache)
    }
}

public class Phi3ModelInner: Module {

    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

    fileprivate let layers: [TransformerBlock]
    let norm: RMSNorm

    public init(_ args: Phi3Configuration) {
        precondition(args.vocabularySize > 0)

        self._embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)

        self.layers = (0 ..< args.hiddenLayers)
            .map { _ in
                TransformerBlock(args)
            }
        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
        MLXArray, [(MLXArray, MLXArray)]
    ) {
        var h = embedTokens(inputs)

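        // a causal mask is only needed when several tokens are processed at
        // once (the prompt); single-token decode steps rely on the KV cache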
        var mask: MLXArray? = nil
        if h.dim(1) > 1 {
            mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
            mask = mask?.asType(h.dtype)
        }

        var newCache = [(MLXArray, MLXArray)]()

        for (i, layer) in layers.enumerated() {
            var cacheUpdate: (MLXArray, MLXArray)
            (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
            newCache.append(cacheUpdate)
        }

        return (norm(h), newCache)
    }
}

public class Phi3Model: Module, LLMModel {

    public let vocabularySize: Int
    let model: Phi3ModelInner

    @ModuleInfo(key: "lm_head") var lmHead: Linear

    public init(_ args: Phi3Configuration) {
        self.vocabularySize = args.vocabularySize
        self.model = Phi3ModelInner(args)
        self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
        MLXArray, [(MLXArray, MLXArray)]
    ) {
        let (out, cache) = model(inputs, cache: cache)
        return (lmHead(out), cache)
    }
}

public struct Phi3Configuration: Codable {

    var hiddenSize: Int
    var hiddenLayers: Int
    var intermediateSize: Int
    var attentionHeads: Int
    var rmsNormEps: Float
    var vocabularySize: Int
    var kvHeads: Int
    var ropeTheta: Float = 10_000
    var ropeTraditional: Bool = false
    var ropeScaling: [String: StringOrNumber]? = nil

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case hiddenLayers = "num_hidden_layers"
        case intermediateSize = "intermediate_size"
        case attentionHeads = "num_attention_heads"
        case rmsNormEps = "rms_norm_eps"
        case vocabularySize = "vocab_size"
        case kvHeads = "num_key_value_heads"
        case ropeTheta = "rope_theta"
        case ropeTraditional = "rope_traditional"
        case ropeScaling = "rope_scaling"
    }

    public init(from decoder: Decoder) throws {
        // custom implementation to handle optional keys with required values
        let container = try decoder.container(keyedBy: CodingKeys.self)

        self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
        self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
        self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
        self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
        self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
        self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
        self.kvHeads = try container.decode(Int.self, forKey: .kvHeads)
        self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 10_000
        self.ropeTraditional =
            try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false
        self.ropeScaling = try container.decodeIfPresent(
            [String: StringOrNumber].self, forKey: .ropeScaling)
    }
}

// MARK: - LoRA

extension Phi3Model: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
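        // because qkv_proj is fused, one LoRA adapter updates Q, K and V together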
        model.layers.map { ($0.attention, ["qkv_proj"]) }
    }
}