* phi3

Co-authored-by: David Koski <dkoski@apple.com>
This commit is contained in:
Awni Hannun
2024-04-24 09:31:01 -07:00
committed by GitHub
parent 6c0b66f90a
commit b951b78eb2
7 changed files with 284 additions and 8 deletions

257
Libraries/LLM/Phi3.swift Normal file
View File

@@ -0,0 +1,257 @@
// Copyright © 2024 Apple Inc.
import Foundation
import MLX
import MLXFast
import MLXNN
private class Attention: Module {
let args: Phi3Configuration
let scale: Float
@ModuleInfo(key: "qkv_proj") var wqkv: Linear
@ModuleInfo(key: "o_proj") var wo: Linear
let rope: RoPE
public init(_ args: Phi3Configuration) {
self.args = args
let dim = args.hiddenSize
let heads = args.attentionHeads
let kvHeads = args.kvHeads
let headDim = args.hiddenSize / heads
self.scale = pow(Float(headDim), -0.5)
self._wqkv.wrappedValue = Linear(dim, (heads + 2 * kvHeads) * headDim, bias: false)
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
let ropeScale: Float
if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
let factor = ropeScaling["factor"]
{
switch factor {
case .string:
fatalError("ropeScaling.factor must be a float")
case .float(let v):
ropeScale = 1 / v
}
} else {
ropeScale = 1
}
self.rope = RoPE(
dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
scale: ropeScale)
}
public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
let (B, L) = (x.dim(0), x.dim(1))
let qkv = split(wqkv(x), parts: 3, axis: -1)
var queries = qkv[0]
var keys = qkv[1]
var values = qkv[2]
// prepare the queries, keys and values for the attention computation
queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
if let (keyCache, valueCache) = cache {
queries = rope(queries, offset: keyCache.dim(2))
keys = rope(keys, offset: keyCache.dim(2))
keys = concatenated([keyCache, keys], axis: 2)
values = concatenated([valueCache, values], axis: 2)
} else {
queries = rope(queries)
keys = rope(keys)
}
let output = MLXFast.scaledDotProductAttention(
queries: queries, keys: keys, values: values, scale: scale, mask: mask
)
.transposed(0, 2, 1, 3)
.reshaped(B, L, -1)
return (wo(output), (keys, values))
}
}
private class MLP: Module, UnaryLayer {
@ModuleInfo(key: "gate_up_proj") var gate_up: Linear
@ModuleInfo(key: "down_proj") var down: Linear
public init(dimensions: Int, hiddenDimensions: Int) {
self._gate_up.wrappedValue = Linear(dimensions, 2 * hiddenDimensions, bias: false)
self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
}
public func callAsFunction(_ x: MLXArray) -> MLXArray {
let gu = split(gate_up(x), parts: 2, axis: -1)
return down(silu(gu[0]) * gu[1])
}
}
private class TransformerBlock: Module {
@ModuleInfo(key: "self_attn") var attention: Attention
let mlp: MLP
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
public init(_ args: Phi3Configuration) {
self._attention.wrappedValue = Attention(args)
self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
self._inputLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
self._postAttentionLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
}
public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
let h = x + r
r = mlp(postAttentionLayerNorm(h))
let out = h + r
return (out, cache)
}
}
public class Phi3ModelInner: Module {
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
fileprivate let layers: [TransformerBlock]
let norm: RMSNorm
public init(_ args: Phi3Configuration) {
precondition(args.vocabularySize > 0)
self._embedTokens.wrappedValue = Embedding(
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
self.layers = (0 ..< args.hiddenLayers)
.map { _ in
TransformerBlock(args)
}
self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
}
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
var h = embedTokens(inputs)
var mask: MLXArray? = nil
if h.dim(1) > 1 {
mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
mask = mask?.asType(h.dtype)
}
var newCache = [(MLXArray, MLXArray)]()
for (i, layer) in layers.enumerated() {
var cacheUpdate: (MLXArray, MLXArray)
(h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
newCache.append(cacheUpdate)
}
return (norm(h), newCache)
}
}
public class Phi3Model: Module, LLMModel {
public let vocabularySize: Int
let model: Phi3ModelInner
@ModuleInfo(key: "lm_head") var lmHead: Linear
public init(_ args: Phi3Configuration) {
self.vocabularySize = args.vocabularySize
self.model = Phi3ModelInner(args)
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
}
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
let (out, cache) = model(inputs, cache: cache)
return (lmHead(out), cache)
}
}
public struct Phi3Configuration: Codable {
var hiddenSize: Int
var hiddenLayers: Int
var intermediateSize: Int
var attentionHeads: Int
var rmsNormEps: Float
var vocabularySize: Int
var kvHeads: Int
var ropeTheta: Float = 10_000
var ropeTraditional: Bool = false
var ropeScaling: [String: StringOrNumber]? = nil
enum CodingKeys: String, CodingKey {
case hiddenSize = "hidden_size"
case hiddenLayers = "num_hidden_layers"
case intermediateSize = "intermediate_size"
case attentionHeads = "num_attention_heads"
case rmsNormEps = "rms_norm_eps"
case vocabularySize = "vocab_size"
case kvHeads = "num_key_value_heads"
case ropeTheta = "rope_theta"
case ropeTraditional = "rope_traditional"
case ropeScaling = "rope_scaling"
}
public init(from decoder: Decoder) throws {
// custom implementation to handle optional keys with required values
let container: KeyedDecodingContainer<Phi3Configuration.CodingKeys> =
try decoder.container(
keyedBy: Phi3Configuration.CodingKeys.self)
self.hiddenSize = try container.decode(
Int.self, forKey: Phi3Configuration.CodingKeys.hiddenSize)
self.hiddenLayers = try container.decode(
Int.self, forKey: Phi3Configuration.CodingKeys.hiddenLayers)
self.intermediateSize = try container.decode(
Int.self, forKey: Phi3Configuration.CodingKeys.intermediateSize)
self.attentionHeads = try container.decode(
Int.self, forKey: Phi3Configuration.CodingKeys.attentionHeads)
self.rmsNormEps = try container.decode(
Float.self, forKey: Phi3Configuration.CodingKeys.rmsNormEps)
self.vocabularySize = try container.decode(
Int.self, forKey: Phi3Configuration.CodingKeys.vocabularySize)
self.kvHeads = try container.decode(Int.self, forKey: Phi3Configuration.CodingKeys.kvHeads)
self.ropeTheta =
try container.decodeIfPresent(
Float.self, forKey: Phi3Configuration.CodingKeys.ropeTheta)
?? 10_000
self.ropeTraditional =
try container.decodeIfPresent(
Bool.self, forKey: Phi3Configuration.CodingKeys.ropeTraditional) ?? false
self.ropeScaling = try container.decodeIfPresent(
[String: StringOrNumber].self, forKey: Phi3Configuration.CodingKeys.ropeScaling)
}
}
// MARK: - LoRA
extension Phi3Model: LoRAModel {
public func loraLinearLayers() -> LoRALinearLayers {
model.layers.map { ($0.attention, ["qkv_proj"]) }
}
}