initial commit

David Koski
2024-02-22 10:41:02 -08:00
commit b6d1e14465
29 changed files with 3856 additions and 0 deletions


@@ -0,0 +1,77 @@
// Copyright © 2024 Apple Inc.
import Foundation
public enum StringOrNumber: Codable, Equatable {
case string(String)
case float(Float)
public init(from decoder: Decoder) throws {
let values = try decoder.singleValueContainer()
if let v = try? values.decode(Float.self) {
self = .float(v)
} else {
let v = try values.decode(String.self)
self = .string(v)
}
}
public func encode(to encoder: Encoder) throws {
var container = encoder.singleValueContainer()
switch self {
case .string(let v): try container.encode(v)
case .float(let v): try container.encode(v)
}
}
}
public enum ModelType: String, Codable {
case mistral
case llama
case phi
case gemma
func createModel(configuration: URL) throws -> LLMModel {
switch self {
case .mistral, .llama:
let configuration = try JSONDecoder().decode(
LlamaConfiguration.self, from: Data(contentsOf: configuration))
return LlamaModel(configuration)
case .phi:
let configuration = try JSONDecoder().decode(
PhiConfiguration.self, from: Data(contentsOf: configuration))
return PhiModel(configuration)
case .gemma:
let configuration = try JSONDecoder().decode(
GemmaConfiguration.self, from: Data(contentsOf: configuration))
return GemmaModel(configuration)
}
}
}
public struct BaseConfiguration: Codable {
let modelType: ModelType
public struct Quantization: Codable {
public init(groupSize: Int, bits: Int) {
self.groupSize = groupSize
self.bits = bits
}
let groupSize: Int
let bits: Int
enum CodingKeys: String, CodingKey {
case groupSize = "group_size"
case bits = "bits"
}
}
var quantization: Quantization?
enum CodingKeys: String, CodingKey {
case modelType = "model_type"
case quantization
}
}
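Only `model_type` and the optional `quantization` block matter at this stage; everything else in `config.json` is decoded later by the model-specific configuration. A minimal decoding sketch follows, using a made-up config fragment (the keys shown are the only ones `BaseConfiguration` reads; real files carry many more, which the decoder ignores):

import Foundation

// Hypothetical config.json fragment, for illustration only.
let configJSON = """
    {
      "model_type": "gemma",
      "quantization": { "group_size": 64, "bits": 4 },
      "hidden_size": 2048
    }
    """.data(using: .utf8)!

let base = try JSONDecoder().decode(BaseConfiguration.self, from: configJSON)
// base.modelType == .gemma
// base.quantization?.groupSize == 64, base.quantization?.bits == 4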

Libraries/LLM/Gemma.swift (new file, 273 lines)

@@ -0,0 +1,273 @@
// Copyright © 2024 Apple Inc.
import Foundation
import MLX
import MLXNN
// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma.py
// specialized norm for gemma
private class RMSNorm: Module, UnaryLayer {
let weight: MLXArray
let eps: Float
public init(dimensions: Int, eps: Float = 1e-5) {
self.weight = MLXArray.ones([dimensions])
self.eps = eps
super.init()
}
func norm(_ x: MLXArray) -> MLXArray {
// RMS norm: scale x by the reciprocal root-mean-square of its last axis
x * rsqrt(x.square().mean(axis: -1, keepDims: true) + eps)
}
public func callAsFunction(_ x: MLXArray) -> MLXArray {
let output = norm(x.asType(Float.self)).asType(x.dtype)
return (1 + weight) * output
}
}
private class Attention: Module {
let args: GemmaConfiguration
let repeats: Int
let scale: Float
@ModuleInfo(key: "q_proj") var wq: Linear
@ModuleInfo(key: "k_proj") var wk: Linear
@ModuleInfo(key: "v_proj") var wv: Linear
@ModuleInfo(key: "o_proj") var wo: Linear
let rope: RoPE
public init(_ args: GemmaConfiguration) {
self.args = args
let dim = args.hiddenSize
let heads = args.attentionHeads
let kvHeads = args.kvHeads
self.repeats = heads / kvHeads
let headDim = args.headDimensions
self.scale = pow(Float(headDim), -0.5)
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
self.rope = RoPE(
dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
}
public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
let (B, L) = (x.dim(0), x.dim(1))
var queries = wq(x)
var keys = wk(x)
var values = wv(x)
// prepare the queries, keys and values for the attention computation
queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
if repeats > 1 {
keys = MLXArray.repeat(keys, count: repeats, axis: 1)
values = MLXArray.repeat(values, count: repeats, axis: 1)
}
if let (keyCache, valueCache) = cache {
queries = rope(queries, offset: keyCache.dim(2))
keys = rope(keys, offset: keyCache.dim(2))
keys = concatenated([keyCache, keys], axis: 2)
values = concatenated([valueCache, values], axis: 2)
} else {
queries = rope(queries)
keys = rope(keys)
}
var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2))
if let mask {
scores = scores + mask
}
scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
return (wo(output), (keys, values))
}
}
private class MLP: Module, UnaryLayer {
@ModuleInfo(key: "gate_proj") var gate: Linear
@ModuleInfo(key: "down_proj") var down: Linear
@ModuleInfo(key: "up_proj") var up: Linear
public init(dimensions: Int, hiddenDimensions: Int) {
self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
}
public func callAsFunction(_ x: MLXArray) -> MLXArray {
down(gelu(gate(x)) * up(x))
}
}
private class TransformerBlock: Module {
@ModuleInfo(key: "self_attn") var attention: Attention
let mlp: MLP
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
public init(_ args: GemmaConfiguration) {
self._attention.wrappedValue = Attention(args)
self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
self._inputLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
self._postAttentionLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
}
public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
let h = x + r
r = mlp(postAttentionLayerNorm(h))
let out = h + r
return (out, cache)
}
}
public class GemmaModelInner: Module {
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
fileprivate let layers: [TransformerBlock]
fileprivate let norm: RMSNorm
let hiddenScale: Float
public init(_ args: GemmaConfiguration) {
precondition(args.vocabularySize > 0)
self._embedTokens.wrappedValue = Embedding(
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
self.hiddenScale = pow(Float(args.hiddenSize), 0.5)
self.layers = (0 ..< args.hiddenLayers)
.map { _ in
TransformerBlock(args)
}
self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
}
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
var h = embedTokens(inputs)
h = h * hiddenScale
var mask: MLXArray? = nil
if h.dim(1) > 1 {
mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
mask = mask?.asType(h.dtype)
}
var newCache = [(MLXArray, MLXArray)]()
for (i, layer) in layers.enumerated() {
var cacheUpdate: (MLXArray, MLXArray)
(h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
newCache.append(cacheUpdate)
}
return (norm(h), newCache)
}
}
public class GemmaModel: Module, LLMModel {
let model: GemmaModelInner
public init(_ args: GemmaConfiguration) {
self.model = GemmaModelInner(args)
}
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
var (out, cache) = model(inputs, cache: cache)
out = matmul(out, model.embedTokens.weight.T)
return (out, cache)
}
}
public struct GemmaConfiguration: Codable {
var hiddenSize: Int
var hiddenLayers: Int
var intermediateSize: Int
var attentionHeads: Int
var headDimensions: Int
var rmsNormEps: Float
var vocabularySize: Int
var kvHeads: Int
var ropeTheta: Float = 10_000
var ropeTraditional: Bool = false
enum CodingKeys: String, CodingKey {
case hiddenSize = "hidden_size"
case hiddenLayers = "num_hidden_layers"
case intermediateSize = "intermediate_size"
case attentionHeads = "num_attention_heads"
case headDimensions = "head_dim"
case rmsNormEps = "rms_norm_eps"
case vocabularySize = "vocab_size"
case kvHeads = "num_key_value_heads"
case ropeTheta = "rope_theta"
case ropeTraditional = "rope_traditional"
}
public init(from decoder: Decoder) throws {
// custom implementation to supply defaults for keys that may be absent but back non-optional properties
let container: KeyedDecodingContainer<CodingKeys> = try decoder.container(
keyedBy: CodingKeys.self)
self.hiddenSize = try container.decode(
Int.self, forKey: CodingKeys.hiddenSize)
self.hiddenLayers = try container.decode(
Int.self, forKey: CodingKeys.hiddenLayers)
self.intermediateSize = try container.decode(
Int.self, forKey: CodingKeys.intermediateSize)
self.attentionHeads = try container.decode(
Int.self, forKey: CodingKeys.attentionHeads)
self.headDimensions = try container.decode(
Int.self, forKey: CodingKeys.headDimensions)
self.rmsNormEps = try container.decode(
Float.self, forKey: CodingKeys.rmsNormEps)
self.vocabularySize = try container.decode(
Int.self, forKey: CodingKeys.vocabularySize)
self.kvHeads = try container.decode(Int.self, forKey: CodingKeys.kvHeads)
self.ropeTheta =
try container.decodeIfPresent(Float.self, forKey: CodingKeys.ropeTheta)
?? 10_000
self.ropeTraditional =
try container.decodeIfPresent(
Bool.self, forKey: CodingKeys.ropeTraditional) ?? false
}
}

Libraries/LLM/LLM.h (new file, 1 line)

@@ -0,0 +1 @@


@@ -0,0 +1,12 @@
// Copyright © 2024 Apple Inc.
import Foundation
import MLX
import MLXNN
// Interface for all LLM Models
public protocol LLMModel: Module {
func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]
)
}

Libraries/LLM/Llama.swift (new file, 263 lines)

@@ -0,0 +1,263 @@
// Copyright © 2024 Apple Inc.
import Foundation
import MLX
import MLXNN
// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py
private class Attention: Module {
let args: LlamaConfiguration
let repeats: Int
let scale: Float
@ModuleInfo(key: "q_proj") var wq: Linear
@ModuleInfo(key: "k_proj") var wk: Linear
@ModuleInfo(key: "v_proj") var wv: Linear
@ModuleInfo(key: "o_proj") var wo: Linear
let rope: RoPE
public init(_ args: LlamaConfiguration) {
self.args = args
let dim = args.hiddenSize
let heads = args.attentionHeads
let kvHeads = args.kvHeads
self.repeats = heads / kvHeads
let headDim = args.hiddenSize / heads
self.scale = pow(Float(headDim), -0.5)
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
let ropeScale: Float
if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
let factor = ropeScaling["factor"]
{
switch factor {
case .string:
fatalError("ropeScaling.factor must be a float")
case .float(let v):
ropeScale = 1 / v
}
} else {
ropeScale = 1
}
self.rope = RoPE(
dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
scale: ropeScale)
}
public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
let (B, L) = (x.dim(0), x.dim(1))
var queries = wq(x)
var keys = wk(x)
var values = wv(x)
// prepare the queries, keys and values for the attention computation
queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
if repeats > 1 {
keys = MLXArray.repeat(keys, count: repeats, axis: 1)
values = MLXArray.repeat(values, count: repeats, axis: 1)
}
if let (keyCache, valueCache) = cache {
queries = rope(queries, offset: keyCache.dim(2))
keys = rope(keys, offset: keyCache.dim(2))
keys = concatenated([keyCache, keys], axis: 2)
values = concatenated([valueCache, values], axis: 2)
} else {
queries = rope(queries)
keys = rope(keys)
}
var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2))
if let mask {
scores = scores + mask
}
scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
return (wo(output), (keys, values))
}
}
private class MLP: Module, UnaryLayer {
@ModuleInfo(key: "gate_proj") var gate: Linear
@ModuleInfo(key: "down_proj") var down: Linear
@ModuleInfo(key: "up_proj") var up: Linear
public init(dimensions: Int, hiddenDimensions: Int) {
self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
}
public func callAsFunction(_ x: MLXArray) -> MLXArray {
down(silu(gate(x)) * up(x))
}
}
private class TransformerBlock: Module {
@ModuleInfo(key: "self_attn") var attention: Attention
let mlp: MLP
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
public init(_ args: LlamaConfiguration) {
self._attention.wrappedValue = Attention(args)
self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
self._inputLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
self._postAttentionLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
}
public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
let h = x + r
r = mlp(postAttentionLayerNorm(h))
let out = h + r
return (out, cache)
}
}
public class LlamaModelInner: Module {
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
fileprivate let layers: [TransformerBlock]
let norm: RMSNorm
public init(_ args: LlamaConfiguration) {
precondition(args.vocabularySize > 0)
self._embedTokens.wrappedValue = Embedding(
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
self.layers = (0 ..< args.hiddenLayers)
.map { _ in
TransformerBlock(args)
}
self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
}
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
var h = embedTokens(inputs)
var mask: MLXArray? = nil
if h.dim(1) > 1 {
mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
mask = mask?.asType(h.dtype)
}
var newCache = [(MLXArray, MLXArray)]()
for (i, layer) in layers.enumerated() {
var cacheUpdate: (MLXArray, MLXArray)
(h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
newCache.append(cacheUpdate)
}
return (norm(h), newCache)
}
}
public class LlamaModel: Module, LLMModel {
let model: LlamaModelInner
@ModuleInfo(key: "lm_head") var lmHead: Linear
public init(_ args: LlamaConfiguration) {
self.model = LlamaModelInner(args)
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
}
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
let (out, cache) = model(inputs, cache: cache)
return (lmHead(out), cache)
}
}
public struct LlamaConfiguration: Codable {
var hiddenSize: Int
var hiddenLayers: Int
var intermediateSize: Int
var attentionHeads: Int
var rmsNormEps: Float
var vocabularySize: Int
var kvHeads: Int
var ropeTheta: Float = 10_000
var ropeTraditional: Bool = false
var ropeScaling: [String: StringOrNumber]? = nil
enum CodingKeys: String, CodingKey {
case hiddenSize = "hidden_size"
case hiddenLayers = "num_hidden_layers"
case intermediateSize = "intermediate_size"
case attentionHeads = "num_attention_heads"
case rmsNormEps = "rms_norm_eps"
case vocabularySize = "vocab_size"
case kvHeads = "num_key_value_heads"
case ropeTheta = "rope_theta"
case ropeTraditional = "rope_traditional"
case ropeScaling = "rope_scaling"
}
public init(from decoder: Decoder) throws {
// custom implementation to supply defaults for keys that may be absent but back non-optional properties
let container: KeyedDecodingContainer<LlamaConfiguration.CodingKeys> =
try decoder.container(
keyedBy: LlamaConfiguration.CodingKeys.self)
self.hiddenSize = try container.decode(
Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenSize)
self.hiddenLayers = try container.decode(
Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenLayers)
self.intermediateSize = try container.decode(
Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
self.attentionHeads = try container.decode(
Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
self.rmsNormEps = try container.decode(
Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
self.vocabularySize = try container.decode(
Int.self, forKey: LlamaConfiguration.CodingKeys.vocabularySize)
self.kvHeads = try container.decode(Int.self, forKey: LlamaConfiguration.CodingKeys.kvHeads)
self.ropeTheta =
try container.decodeIfPresent(
Float.self, forKey: LlamaConfiguration.CodingKeys.ropeTheta)
?? 10_000
self.ropeTraditional =
try container.decodeIfPresent(
Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false
self.ropeScaling = try container.decodeIfPresent(
[String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling)
}
}
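The `rope_scaling` entry is where `StringOrNumber` earns its keep: `type` arrives as a string and `factor` as a number in the same dictionary. A small decoding sketch with a hypothetical config fragment (values invented for illustration):

// Hypothetical Llama config.json with a linear rope_scaling entry.
let llamaJSON = """
    {
      "hidden_size": 4096, "num_hidden_layers": 32, "intermediate_size": 11008,
      "num_attention_heads": 32, "num_key_value_heads": 8,
      "rms_norm_eps": 1e-5, "vocab_size": 32000,
      "rope_scaling": { "type": "linear", "factor": 2.0 }
    }
    """.data(using: .utf8)!

let llamaConfig = try JSONDecoder().decode(LlamaConfiguration.self, from: llamaJSON)
// llamaConfig.ropeScaling?["factor"] == .float(2.0)
// Attention.init above then uses ropeScale = 1 / 2.0 when building the RoPE layer.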

Libraries/LLM/Phi.swift (new file, 302 lines)

@@ -0,0 +1,302 @@
// Copyright © 2024 Apple Inc.
import Foundation
import MLX
import MLXNN
// https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py
// TODO: remove once MLXNN exposes LayerNorm as an open (subclassable) class
public class MLXLayerNorm: Module, UnaryLayer {
let dimensions: Int
let eps: Float
let weight: MLXArray?
let bias: MLXArray?
/// Applies layer normalization [1] on the inputs.
///
/// See [LayerNorm python docs](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.LayerNorm.html) for more information.
///
/// ### References
/// 1. [https://arxiv.org/abs/1607.06450](https://arxiv.org/abs/1607.06450)
///
/// - Parameters:
/// - dimensions: number of features in the input
/// - eps: value added to the denominator for numerical stability
/// - affine: if `true` adds a trainable `weight` and `bias`
public init(dimensions: Int, eps: Float = 1e-5, affine: Bool = true) {
self.dimensions = dimensions
self.eps = eps
if affine {
self.weight = MLXArray.ones([dimensions])
self.bias = MLXArray.zeros([dimensions])
} else {
self.weight = nil
self.bias = nil
}
}
public func callAsFunction(_ x: MLXArray) -> MLXArray {
let means = mean(x, axis: -1, keepDims: true)
let variance = variance(x, axis: -1, keepDims: true)
let x = (x - means) * rsqrt(variance + eps)
if let weight, let bias {
return weight * x + bias
} else {
return x
}
}
}
private class LayerNorm: MLXLayerNorm {
override func callAsFunction(_ x: MLXArray) -> MLXArray {
super.callAsFunction(x.asType(Float.self)).asType(x.dtype)
}
}
private class PhiAttention: Module {
let args: PhiConfiguration
let heads: Int
let headDim: Int
let repeats: Int
@ModuleInfo(key: "q_proj") var wq: Linear
@ModuleInfo(key: "k_proj") var wk: Linear
@ModuleInfo(key: "v_proj") var wv: Linear
@ModuleInfo(key: "dense") var dense: Linear
let rope: RoPE
public init(_ args: PhiConfiguration) {
self.args = args
let hiddenSize = args.hiddenSize
self.heads = args.attentionHeads
self.headDim = args.hiddenSize / heads
let kvHeads = args.kvHeads
self.repeats = heads / kvHeads
if headDim * heads != hiddenSize {
fatalError("hidden_size must be divisible by num_heads")
}
self._wq.wrappedValue = Linear(hiddenSize, heads * headDim, bias: true)
self._wk.wrappedValue = Linear(hiddenSize, kvHeads * headDim, bias: true)
self._wv.wrappedValue = Linear(hiddenSize, kvHeads * headDim, bias: true)
self._dense.wrappedValue = Linear(heads * headDim, hiddenSize, bias: true)
self.rope = RoPE(
dimensions: Int(args.partialRotaryFactor * Float(headDim)), traditional: false,
base: args.ropeTheta)
}
public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
let (B, L) = (x.dim(0), x.dim(1))
var queries = wq(x)
var keys = wk(x)
var values = wv(x)
// prepare the queries, keys and values for the attention computation
queries = queries.reshaped(B, L, heads, headDim).transposed(0, 2, 1, 3)
keys = keys.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3)
if repeats > 1 {
keys = MLXArray.repeat(keys, count: repeats, axis: 1)
values = MLXArray.repeat(values, count: repeats, axis: 1)
}
// Add RoPE to the queries and keys and combine them with the cache
if let (keyCache, valueCache) = cache {
queries = rope(queries, offset: keyCache.dim(2))
keys = rope(keys, offset: keyCache.dim(2))
keys = concatenated([keyCache, keys], axis: 2)
values = concatenated([valueCache, values], axis: 2)
} else {
queries = rope(queries)
keys = rope(keys)
}
queries = queries.asType(Float.self)
keys = keys.asType(Float.self)
// Finally perform the attention computation
let scale = sqrt(1 / Float(queries.dim(-1)))
var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
if let mask {
scores = scores + mask
}
scores = softMax(scores, axis: -1).asType(values.dtype)
let valuesHat = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
return (dense(valuesHat), (keys, values))
}
}
private class PhiMLP: Module, UnaryLayer {
@ModuleInfo var fc1: Linear
@ModuleInfo var fc2: Linear
@ModuleInfo var act: GELU
public init(_ config: PhiConfiguration) {
self.fc1 = Linear(config.hiddenSize, config.intermediateSize)
self.fc2 = Linear(config.intermediateSize, config.hiddenSize)
self.act = GELU(approximation: .precise)
}
public func callAsFunction(_ x: MLXArray) -> MLXArray {
fc2(act(fc1(x)))
}
}
private class PhiDecoderLayer: Module {
@ModuleInfo(key: "self_attn") var selfAttention: PhiAttention
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: LayerNorm
var mlp: PhiMLP
public init(_ config: PhiConfiguration) {
self._selfAttention.wrappedValue = PhiAttention(config)
self._inputLayerNorm.wrappedValue = LayerNorm(
dimensions: config.hiddenSize, eps: config.layerNormEps)
self.mlp = PhiMLP(config)
}
public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
let h = inputLayerNorm(x)
let (attentionH, cache) = selfAttention(h, mask: mask, cache: cache)
let ffH = mlp(h)
return (attentionH + ffH + x, cache)
}
}
private class PhiModelInner: Module {
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
@ModuleInfo var layers: [PhiDecoderLayer]
@ModuleInfo(key: "final_layernorm") var finalLayerNorm: LayerNorm
public init(_ args: PhiConfiguration) {
self._embedTokens.wrappedValue = Embedding(
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
self.layers = (0 ..< args.hiddenLayers)
.map { _ in
PhiDecoderLayer(args)
}
self._finalLayerNorm.wrappedValue = LayerNorm(
dimensions: args.hiddenSize, eps: args.layerNormEps)
}
public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: [(MLXArray, MLXArray)]? = nil
) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
var x = embedTokens(x)
var newCache = [(MLXArray, MLXArray)]()
for (i, layer) in layers.enumerated() {
var cacheUpdate: (MLXArray, MLXArray)
(x, cacheUpdate) = layer(x, mask: mask, cache: cache?[i])
newCache.append(cacheUpdate)
}
return (finalLayerNorm(x), newCache)
}
}
public class PhiModel: Module, LLMModel {
fileprivate let model: PhiModelInner
@ModuleInfo(key: "lm_head") var lmHead: Linear
public init(_ args: PhiConfiguration) {
self.model = PhiModelInner(args)
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: true)
}
public func callAsFunction(_ x: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
var mask: MLXArray? = nil
if x.dim(1) > 1 {
mask = MultiHeadAttention.createAdditiveCausalMask(x.dim(1))
mask = mask?.asType(x.dtype)
}
let (y, cache) = model(x, mask: mask, cache: cache)
return (lmHead(y), cache)
}
}
public struct PhiConfiguration: Codable {
var maxPositionalEmbeddings = 2048
var vocabularySize = 51200
var hiddenSize = 2560
var attentionHeads = 32
var hiddenLayers = 32
var kvHeads = 32
var partialRotaryFactor: Float = 0.4
var intermediateSize = 10240
var layerNormEps: Float = 1e-5
var ropeTheta: Float = 10_000
enum CodingKeys: String, CodingKey {
case maxPositionalEmbeddings = "max_position_embeddings"
case vocabularySize = "vocab_size"
case hiddenSize = "hidden_size"
case attentionHeads = "num_attention_heads"
case hiddenLayers = "num_hidden_layers"
case kvHeads = "num_key_value_heads"
case partialRotaryFactor = "partial_rotary_factor"
case intermediateSize = "intermediate_size"
case layerNormEps = "layer_norm_eps"
case ropeTheta = "rope_theta"
}
public init(from decoder: Decoder) throws {
let container: KeyedDecodingContainer<PhiConfiguration.CodingKeys> = try decoder.container(
keyedBy: PhiConfiguration.CodingKeys.self)
self.maxPositionalEmbeddings = try container.decode(
Int.self, forKey: PhiConfiguration.CodingKeys.maxPositionalEmbeddings)
self.vocabularySize = try container.decode(
Int.self, forKey: PhiConfiguration.CodingKeys.vocabularySize)
self.hiddenSize = try container.decode(
Int.self, forKey: PhiConfiguration.CodingKeys.hiddenSize)
self.attentionHeads = try container.decode(
Int.self, forKey: PhiConfiguration.CodingKeys.attentionHeads)
self.hiddenLayers = try container.decode(
Int.self, forKey: PhiConfiguration.CodingKeys.hiddenLayers)
self.kvHeads =
try container.decodeIfPresent(Int.self, forKey: PhiConfiguration.CodingKeys.kvHeads)
?? attentionHeads
self.partialRotaryFactor = try container.decode(
Float.self, forKey: PhiConfiguration.CodingKeys.partialRotaryFactor)
self.intermediateSize = try container.decode(
Int.self, forKey: PhiConfiguration.CodingKeys.intermediateSize)
self.layerNormEps = try container.decode(
Float.self, forKey: PhiConfiguration.CodingKeys.layerNormEps)
self.ropeTheta =
try container.decodeIfPresent(Float.self, forKey: PhiConfiguration.CodingKeys.ropeTheta)
?? 10_000
}
}

Libraries/LLM/README.md (new file, 11 lines)

@@ -0,0 +1,11 @@
# Llama

This is a port of the llama model from:

- https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py

You can use this to load models from huggingface, e.g.:

- https://huggingface.co/mlx-community/Mistral-7B-v0.1-hf-4bit-mlx

See [llm-tool](../../Tools/llm-tool)
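As a rough usage sketch (not part of this commit's README), loading one of those models and sampling a short completion with the `load` and `TokenIterator` APIs from `Util.swift` below might look like this; the prompt text and the 100-token cap are arbitrary choices:

import MLX
import Tokenizers

func runExample() async throws {
    // download config + weights and build the model and tokenizer
    let (model, tokenizer) = try await load(name: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")

    let promptTokens = tokenizer.encode(text: "Compare Python and Swift.")
    let prompt = MLXArray(promptTokens.map { Int32($0) })

    var tokens = [Int]()
    for token in TokenIterator(prompt: prompt, model: model, temp: 0.0) {
        // each element is a one-element MLXArray holding the next token id
        tokens.append(token.item(Int.self))
        if tokens.count >= 100 { break }
    }
    print(tokenizer.decode(tokens: tokens))
}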

Libraries/LLM/Util.swift (new file, 110 lines)

@@ -0,0 +1,110 @@
// Copyright © 2024 Apple Inc.
import AsyncAlgorithms
import Foundation
import Hub
import MLX
import MLXNN
import MLXRandom
import Tokenizers
/// Load and return the model and tokenizer
public func load(
hub: HubApi = HubApi(), name: String, progressHandler: @escaping (Progress) -> Void = { _ in }
) async throws -> (LLMModel, Tokenizer) {
// note: this doesn't have a way to pass the HubApi
let tokenizer = try await AutoTokenizer.from(pretrained: name)
// download the model weights and config
let repo = Hub.Repo(id: name)
let modelFiles = ["config.json", "weights.00.safetensors"]
let modelDirectory = try await hub.snapshot(
from: repo, matching: modelFiles, progressHandler: progressHandler)
// create the model (no weights loaded)
let configurationURL = modelDirectory.appending(component: "config.json")
let baseConfig = try JSONDecoder().decode(
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
let model = try baseConfig.modelType.createModel(configuration: configurationURL)
// set up the model
if let quantization = baseConfig.quantization {
QuantizedLinear.quantize(
model: model, groupSize: quantization.groupSize, bits: quantization.bits)
}
// apply the loaded weights
let weights = try loadArrays(url: modelDirectory.appending(component: "weights.00.safetensors"))
let parameters = ModuleParameters.unflattened(weights)
try model.update(parameters: parameters, verify: [.all])
eval(model.parameters())
return (model, tokenizer)
}
private func sample(logits: MLXArray, temp: Float) -> MLXArray {
if temp == 0 {
return argMax(logits, axis: -1)
} else {
return categorical(logits * (1 / temp))
}
}
/// Synchronous generator of tokens.
///
/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py
public struct TokenIterator: Sequence, IteratorProtocol {
let model: LLMModel
let temp: Float
var y: MLXArray
var cache: [(MLXArray, MLXArray)]
var first = true
public init(prompt: MLXArray, model: LLMModel, temp: Float = 0.0) {
self.model = model
self.temp = temp
self.y = prompt
self.cache = []
}
mutating public func next() -> MLXArray? {
var logits: MLXArray
(logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
y = sample(logits: logits[-1, axis: 1], temp: temp)
return y
}
}
/// Async generator of tokens.
///
/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py.
///
/// Note that because MLXArray is not thread safe, this evaluates the result and sends the
/// token id (an Int) back to the caller.
public func generate(prompt: MLXArray, model: LLMModel, temp: Float = 0.0) -> (
Task<Void, Never>, AsyncBufferSequence<AsyncChannel<Int>>
) {
let channel = AsyncChannel<Int>()
let buffer = channel.buffer(policy: .bounded(10))
let task = Task {
var y = prompt
var cache = [(MLXArray, MLXArray)]()
while !Task.isCancelled {
var logits: MLXArray
(logits, cache) = model(
expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
y = sample(logits: logits[-1, axis: 1], temp: temp)
eval(y)
await channel.send(y.item(Int.self))
}
}
return (task, buffer)
}
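For the async variant, a consumer might look like the sketch below (assumed usage, not included in this commit): iterate the buffered channel and cancel the producer task once enough tokens have arrived, since `generate` itself loops until cancelled.

func streamExample(prompt: MLXArray, model: LLMModel) async {
    let (task, stream) = generate(prompt: prompt, model: model, temp: 0.7)

    var count = 0
    for await tokenId in stream {
        // tokenId is a plain Int, already evaluated on the producer side
        print(tokenId, terminator: " ")
        count += 1
        if count >= 100 { break }
    }

    // the producer loops until cancelled, so stop it explicitly
    task.cancel()
}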