//
// OpenELM.swift
// LLM
//
// Created by Sachin Desai on 2024/4/27.
//
import Foundation
import MLX
import MLXFast
import MLXNN
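
/// Returns the number of attention heads for a given model width and per-head dimension.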
func computeHeads(modelDim: Int, headDim: Int) -> Int {
    assert(modelDim % headDim == 0, "modelDim must be divisible by headDim")
    return modelDim / headDim
}
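
/// Rounds `v` to a multiple of `divisor` (no less than `minValue`), adding one
/// extra divisor if rounding would drop the result below 90% of `v`.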
func makeDivisible(_ v: Float, divisor: Int = 8, minValue: Float? = nil) -> Int {
    let minVal = minValue ?? Float(divisor)
    var roundDown = max(minVal, Float(Int((v + Float(divisor) / 2) / Float(divisor)) * divisor))
    if roundDown < 0.9 * v {
        roundDown += Float(divisor)
    }
    return Int(roundDown)
}
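
/// Multi-head causal attention with grouped-query KV heads, rotary position
/// embeddings, and optional RMSNorm on the query/key projections.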
private class MultiHeadCausalAttention: Module {
    let args: OpenElmConfiguration
    let scale: Float
    let heads: Int
    let headDim: Int
    let kvHeads: Int

    @ModuleInfo(key: "qkv_proj") var qkvProj: Linear
    @ModuleInfo(key: "out_proj") var outProj: Linear
    @ModuleInfo(key: "q_norm") var qNorm: RMSNorm
    @ModuleInfo(key: "k_norm") var kNorm: RMSNorm

    let rope: RoPE

    public init(_ args: OpenElmConfiguration, layerId: Int) {
        self.args = args
        self.headDim = args.headDimensions
        let modelDim = args.modelDim
        self.heads = self.args.numQueryHeads[layerId]
        self.kvHeads = self.args.kvHeads[layerId]
        self.scale = pow(Float(headDim), -0.5)

        let opSize = (heads + (kvHeads * 2)) * headDim
        self._qkvProj.wrappedValue = Linear(modelDim, opSize, bias: false)
        self._outProj.wrappedValue = Linear(heads * headDim, modelDim, bias: false)

        if args.normalizeQkProjections {
            self._qNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)
            self._kNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)
        }

        self.rope = RoPE(
            dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
    }
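
    /// Splits the fused QKV projection into per-head queries, keys, and values,
    /// optionally RMS-normalizes queries and keys, applies RoPE, and returns the
    /// output projection together with the updated `(keys, values)` cache.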
    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
    ) -> (MLXArray, (MLXArray, MLXArray)) {
        let (B, L) = (x.dim(0), x.dim(1))

        let qkv = qkvProj(x)
            .reshaped(B, L, heads + (kvHeads * 2), headDim)
            .transposed(0, 2, 1, 3)
        let qkvSplit = split(qkv, indices: [heads, heads + kvHeads], axis: 1)

        var queries = qkvSplit[0]
        var keys = qkvSplit[1]
        var values = qkvSplit[2]

        if args.normalizeQkProjections {
            queries = qNorm(queries)
            keys = kNorm(keys)
        }
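
        // When a cache is present, offset RoPE by the cached sequence length and
        // append the new keys/values along the time axis.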
        if let (keyCache, valueCache) = cache {
            queries = rope(queries, offset: keyCache.dim(2))
            keys = rope(keys, offset: keyCache.dim(2))
            keys = concatenated([keyCache, keys], axis: 2)
            values = concatenated([valueCache, values], axis: 2)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        let output = MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: scale, mask: mask
        ).transposed(0, 2, 1, 3).reshaped(B, L, heads * headDim)

        return (outProj(output), (keys, values))
    }
}
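
/// SwiGLU-style feed-forward network: `proj_1` emits the gate and value halves in
/// a single projection, `proj_2` maps back to the model dimension. The underscored
/// property names are kept as-is, presumably so the auto-generated parameter keys
/// match the checkpoint's `proj_1`/`proj_2` weights.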
private class FeedForwardNetwork: Module, UnaryLayer {
    @ModuleInfo var proj_1: Linear
    @ModuleInfo var proj_2: Linear

    public init(_ args: OpenElmConfiguration, layerId: Int) {
        let dim = args.modelDim
        let ffnMultiplier = args.ffnMultipliers[layerId]
        let intermediateDim = makeDivisible(
            ffnMultiplier * Float(dim), divisor: args.ffnDimDivisor)

        self.proj_1 = Linear(dim, 2 * intermediateDim)
        self.proj_2 = Linear(intermediateDim, dim)
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        let parts = split(proj_1(x), parts: 2, axis: -1)
        let (gate, value) = (parts[0], parts[1])
        return proj_2(silu(gate) * value)
    }
}
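
/// Pre-norm decoder block: attention over `attn_norm(x)` with a residual
/// connection, followed by the FFN over `ffn_norm(h)` with a second residual.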
private class TransformerDecoderLayer: Module {
    @ModuleInfo(key: "attn") var attn: MultiHeadCausalAttention
    let ffn: FeedForwardNetwork
    @ModuleInfo(key: "ffn_norm") var ffnNorm: RMSNorm
    @ModuleInfo(key: "attn_norm") var attnNorm: RMSNorm

    public init(_ args: OpenElmConfiguration, layerId: Int) {
        let dim = args.modelDim
        self._attn.wrappedValue = MultiHeadCausalAttention(args, layerId: layerId)
        self.ffn = FeedForwardNetwork(args, layerId: layerId)
        self._ffnNorm.wrappedValue = RMSNorm(dimensions: dim, eps: args.rmsNormEps)
        self._attnNorm.wrappedValue = RMSNorm(dimensions: dim, eps: args.rmsNormEps)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
    ) -> (MLXArray, (MLXArray, MLXArray)) {
        var (r, cache) = attn(attnNorm(x), mask: mask, cache: cache)
        let h = x + r
        r = ffn(ffnNorm(h))
        let out = h + r
        return (out, cache)
    }
}
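
/// The transformer backbone: token embeddings, the stack of decoder layers, and a
/// final RMSNorm over the hidden states.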
class OpenELMModelInner: Module, LLMModel {
    var vocabularySize: Int

    @ModuleInfo(key: "token_embeddings") var embedTokens: Embedding

    fileprivate let layers: [TransformerDecoderLayer]
    fileprivate let norm: RMSNorm

    public init(_ args: OpenElmConfiguration) {
        precondition(args.vocabularySize > 0)

        self.vocabularySize = args.vocabularySize
        self._embedTokens.wrappedValue = Embedding(
            embeddingCount: self.vocabularySize, dimensions: args.modelDim)
        self.layers = (0 ..< args.numTransformerLayers)
            .map { layerId in
                TransformerDecoderLayer(args, layerId: layerId)
            }
        self.norm = RMSNorm(dimensions: args.modelDim, eps: args.rmsNormEps)
    }
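
    // Embeds the input tokens, builds an additive causal mask when decoding more
    // than one position at once, and threads the per-layer KV cache through the stack.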
    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
        MLXArray, [(MLXArray, MLXArray)]
    ) {
        var h = embedTokens(inputs)

        var mask: MLXArray? = nil
        if h.dim(1) > 1 {
            mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
            mask = mask?.asType(h.dtype)
        }

        var newCache = [(MLXArray, MLXArray)]()
        for (i, layer) in layers.enumerated() {
            var cacheUpdate: (MLXArray, MLXArray)
            (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
            newCache.append(cacheUpdate)
        }

        return (norm(h), newCache)
    }
}
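
/// Top-level OpenELM model. When the configuration sets
/// `share_input_output_layers`, logits are computed against the transposed token
/// embedding matrix instead of a separate `lm_head` projection.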
public class OpenELMModel: Module, LLMModel {
    public let vocabularySize: Int

    let shareInputOutputLayers: Bool
    let transformer: OpenELMModelInner

    @ModuleInfo(key: "lm_head") var lmHead: Linear

    public init(_ args: OpenElmConfiguration) {
        self.vocabularySize = args.vocabularySize
        self.transformer = OpenELMModelInner(args)
        self.shareInputOutputLayers = args.shareInputOutputLayers
        self._lmHead.wrappedValue = Linear(
            args.modelDim, args.vocabularySize, bias: false)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
        MLXArray, [(MLXArray, MLXArray)]
    ) {
        var (out, cache) = transformer(inputs, cache: cache)
        if shareInputOutputLayers {
            out = matmul(out, transformer.embedTokens.weight.T)
        } else {
            out = lmHead(out)
        }
        return (out, cache)
    }
}
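
/// Configuration decoded from the model's `config.json`. The per-layer query/KV
/// head counts and FFN multipliers are not read from the file; they are derived
/// in `init(from:)` from the interpolation bounds below.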
public struct OpenElmConfiguration: Codable {
    var modelType: String
    var headDimensions: Int
    var numTransformerLayers: Int
    var modelDim: Int
    var vocabularySize: Int
    var ffnDimDivisor: Int
    var numQueryHeads: [Int] = []
    var kvHeads: [Int] = []
    var ffnWithGlu: Bool = true
    var normalizeQkProjections: Bool = true
    var shareInputOutputLayers: Bool = true
    var rmsNormEps: Float = 1e-6
    var ropeTheta: Float = 10_000
    var ropeTraditional: Bool = false
    var numGqaGroups: Int = 4
    var ffnMultipliers: [Float] = [0.5, 4.0]
    var qkvMultiplier: [Float] = [0.5, 1.0]

    enum CodingKeys: String, CodingKey {
        case modelType = "model_type"
        case headDimensions = "head_dim"
        case numTransformerLayers = "num_transformer_layers"
        case modelDim = "model_dim"
        case vocabularySize = "vocab_size"
        case ffnDimDivisor = "ffn_dim_divisor"
        case ffnMultipliers = "ffn_multipliers"
        case ffnWithGlu = "ffn_with_glu"
        case normalizeQkProjections = "normalize_qk_projections"
        case shareInputOutputLayers = "share_input_output_layers"
    }

    public init(from decoder: Decoder) throws {
        // Custom decoding: several keys are optional in config.json, and the
        // per-layer arrays are derived below rather than read from the file.
        let container = try decoder.container(keyedBy: CodingKeys.self)

        self.modelType = try container.decode(String.self, forKey: .modelType)
        self.headDimensions = try container.decode(Int.self, forKey: .headDimensions)
        self.numTransformerLayers = try container.decode(Int.self, forKey: .numTransformerLayers)
        self.modelDim = try container.decode(Int.self, forKey: .modelDim)
        self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
        self.ffnDimDivisor = try container.decode(Int.self, forKey: .ffnDimDivisor)
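
        // Derive per-layer widths: interpolate the qkv and ffn multipliers linearly
        // from the first to the last layer, then round query dimensions to multiples
        // of headDim * numGqaGroups so every layer splits evenly into GQA groups.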
        let qkvMultipliers = stride(
            from: qkvMultiplier[0], through: qkvMultiplier[1],
            by: (qkvMultiplier[1] - qkvMultiplier[0]) / Float(numTransformerLayers - 1)
        )
        .map { round($0 * 100) / 100 }

        let headMultipleOf = numGqaGroups
        let queryDims = qkvMultipliers.map { a in
            makeDivisible(Float(self.modelDim) * a, divisor: self.headDimensions * headMultipleOf)
        }

        self.numQueryHeads = queryDims.map { qDim in
            computeHeads(modelDim: qDim, headDim: self.headDimensions)
        }
        self.kvHeads = self.numQueryHeads.map { qHeads in
            qHeads / numGqaGroups
        }

        self.ffnMultipliers = stride(
            from: ffnMultipliers[0], through: ffnMultipliers[1],
            by: (ffnMultipliers[1] - ffnMultipliers[0]) / Float(numTransformerLayers - 1)
        )
        .map { round($0 * 100) / 100 }

        self.ffnWithGlu =
            try container.decodeIfPresent(Bool.self, forKey: .ffnWithGlu) ?? true
        self.normalizeQkProjections =
            try container.decodeIfPresent(Bool.self, forKey: .normalizeQkProjections) ?? true
        self.shareInputOutputLayers =
            try container.decodeIfPresent(Bool.self, forKey: .shareInputOutputLayers) ?? true
    }
}

// MARK: - LoRA

extension OpenELMModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        transformer.layers.map { ($0.attn, ["qkv_proj"]) }
    }
}