Add Llama 3.1 (#98)

* Update Mistral 7B config

* Add Mistral NeMo

* Update for Llama 3.1

* Align LlamaConfiguration with Python implementation

* Fix model configuration names

* Refine DynamicNTKScalingRoPE

* Compute base only once

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Anthony
Date: 2024-07-26 22:05:42 +02:00
Committed by: GitHub
Parent: c4fda0e036
Commit: ac6bdfccec

3 changed files with 200 additions and 84 deletions


@@ -159,7 +159,7 @@ class LLMEvaluator {

     /// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on
     /// more devices
-    let modelConfiguration = ModelConfiguration.phi34bit
+    let modelConfiguration = ModelConfiguration.phi3_4bit

     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)


@@ -7,6 +7,86 @@ import MLXNN

 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py

+func computeBaseFrequency(
+    base: Float, dims: Int, ropeType: String, ropeScaling: [String: StringOrNumber]?
+)
+    -> Float
+{
+    if ropeType != "llama3" {
+        return base
+    }
+
+    guard let ropeScaling = ropeScaling else {
+        return base
+    }
+
+    guard case .float(let factor) = ropeScaling["factor"],
+        case .float(let lowFreqFactor) = ropeScaling["low_freq_factor"] ?? .float(1.0),
+        case .float(let highFreqFactor) = ropeScaling["high_freq_factor"] ?? .float(4.0),
+        case .float(let oldContextLen) = ropeScaling["original_max_position_embeddings"]
+            ?? .float(8192)
+    else {
+        return base
+    }
+
+    let lowFreqWavelen = oldContextLen / lowFreqFactor
+    let highFreqWavelen = oldContextLen / highFreqFactor
+
+    let freqs = (0 ..< dims).compactMap { index -> Float? in
+        if index % 2 == 0 {
+            return pow(base, Float(index) / Float(dims))
+        }
+        return nil
+    }
+
+    let newBaseFreqs = freqs.map { freq -> Float in
+        let wavelen = 2 * .pi / freq
+        let smooth = max(
+            0, min(1, (wavelen - highFreqWavelen) / (lowFreqWavelen - highFreqWavelen)))
+        return freq * ((1 - smooth) * factor + smooth)
+    }
+
+    return newBaseFreqs.reduce(0, +) / Float(newBaseFreqs.count)
+}
+
+private class DynamicNTKScalingRoPE: Module {
+    let dims: Int
+    let maxPositionEmbeddings: Int?
+    let traditional: Bool
+    let base: Float
+    var scale: Float
+    let ropeType: String
+    let ropeScaling: [String: StringOrNumber]?
+
+    init(
+        dims: Int, maxPositionEmbeddings: Int?, traditional: Bool = false,
+        base: Float = 10000, scale: Float = 1.0, ropeType: String = "default",
+        ropeScaling: [String: StringOrNumber]? = nil
+    ) {
+        self.dims = dims
+        self.maxPositionEmbeddings = maxPositionEmbeddings
+        self.traditional = traditional
+        self.base = computeBaseFrequency(
+            base: base, dims: dims, ropeType: ropeType, ropeScaling: ropeScaling)
+        self.scale = scale
+        self.ropeType = ropeType
+        self.ropeScaling = ropeScaling
+    }
+
+    func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
+        let seqLen = x.dim(1) + offset
+        var base = self.base
+        if let maxPositionEmbeddings, seqLen > maxPositionEmbeddings {
+            let factorAdjustment = Float(seqLen) / Float(maxPositionEmbeddings) - 1
+            let dimensionRatio = Float(dims) / Float(Float(dims) - 2)
+            let adjustedScale = scale * pow(1 + factorAdjustment, dimensionRatio)
+            base *= adjustedScale
+        }
+        return MLXFast.RoPE(
+            x, dimensions: dims, traditional: traditional, base: base, scale: scale, offset: offset)
+    }
+}
+
 private class Attention: Module {

     let args: LlamaConfiguration
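
The llama3 branch of `computeBaseFrequency` folds the per-dimension frequency corrections into one averaged base. Below is a standalone sketch of that arithmetic, assuming the `rope_scaling` values Llama 3.1 commonly ships (factor 8.0, low_freq_factor 1.0, high_freq_factor 4.0, original_max_position_embeddings 8192) and a base of 500,000 — none of these numbers come from this diff.

```swift
import Foundation

// Assumed Llama 3.1-style rope_scaling values; adjust to match a real config.json.
let base: Float = 500_000
let dims = 128
let factor: Float = 8.0
let lowFreqFactor: Float = 1.0
let highFreqFactor: Float = 4.0
let oldContextLen: Float = 8192

let lowFreqWavelen = oldContextLen / lowFreqFactor  // 8192
let highFreqWavelen = oldContextLen / highFreqFactor  // 2048

// One frequency per even dimension, mirroring the compactMap above.
let freqs = stride(from: 0, to: dims, by: 2).map { pow(base, Float($0) / Float(dims)) }

let newBaseFreqs = freqs.map { freq -> Float in
    let wavelen = 2 * Float.pi / freq
    // smooth interpolates between scaling by `factor` (smooth == 0)
    // and leaving the frequency untouched (smooth == 1).
    let smooth = max(0, min(1, (wavelen - highFreqWavelen) / (lowFreqWavelen - highFreqWavelen)))
    return freq * ((1 - smooth) * factor + smooth)
}

// The port collapses the per-dimension result into a single mean base.
let newBase = newBaseFreqs.reduce(0, +) / Float(newBaseFreqs.count)
print(newBase)
```
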
@@ -17,9 +97,9 @@ private class Attention: Module {
     @ModuleInfo(key: "v_proj") var wv: Linear
     @ModuleInfo(key: "o_proj") var wo: Linear

-    let rope: RoPE
+    let rope: DynamicNTKScalingRoPE

-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
         self.args = args

         let dim = args.hiddenSize
@@ -29,31 +109,28 @@ private class Attention: Module {
         let headDim = args.headDimensions ?? (args.hiddenSize / heads)
         self.scale = pow(Float(headDim), -0.5)

-        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
-        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
+        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: args.attentionBias)
+        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
+        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
+        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: args.attentionBias)

-        let ropeScale: Float
-        if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
-            let factor = ropeScaling["factor"]
-        {
-            switch factor {
-            case .string:
-                fatalError("ropeScaling.factor must be a float")
-            case .float(let v):
-                ropeScale = 1 / v
-            }
-        } else {
-            ropeScale = 1
-        }
-
-        self.rope = RoPE(
-            dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
-            scale: ropeScale)
+        self.rope = DynamicNTKScalingRoPE(
+            dims: headDim,
+            maxPositionEmbeddings: args.maxPositionEmbeddings,
+            traditional: args.ropeTraditional,
+            base: args.ropeTheta,
+            scale: 1.0,
+            ropeType: {
+                if case .string(let value) = args.ropeScaling?["type"] {
+                    return value
+                } else {
+                    return "default"
+                }
+            }(),
+            ropeScaling: args.ropeScaling)
     }

-    public func callAsFunction(
+    func callAsFunction(
         _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
     ) -> (MLXArray, (MLXArray, MLXArray)) {
         let (B, L) = (x.dim(0), x.dim(1))
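
When the visible sequence outgrows `maxPositionEmbeddings`, `callAsFunction` rescales the RoPE base on the fly rather than truncating. A worked sketch of just that adjustment, with assumed numbers (head dimension 128, base 10,000, a 4,096-token window, an 8,192-token sequence):

```swift
import Foundation

let dims: Float = 128
let base: Float = 10_000
let scale: Float = 1.0
let maxPositionEmbeddings: Float = 4096
let seqLen: Float = 8192  // twice the trained window

// factorAdjustment is 1.0 for a 2x overshoot.
let factorAdjustment = seqLen / maxPositionEmbeddings - 1
// dims / (dims - 2) is just above 1, so the exponent slightly amplifies the stretch.
let dimensionRatio = dims / (dims - 2)
let adjustedBase = base * scale * pow(1 + factorAdjustment, dimensionRatio)
print(adjustedBase)  // ~20_222: roughly a doubled base for a doubled context
```
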
@@ -62,7 +139,7 @@ private class Attention: Module {
         var keys = wk(x)
         var values = wv(x)

-        // prepare the queries, keys and values for the attention computation
+        // Prepare the queries, keys and values for the attention computation
         queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
@@ -93,35 +170,35 @@ private class MLP: Module, UnaryLayer {
     @ModuleInfo(key: "down_proj") var down: Linear
     @ModuleInfo(key: "up_proj") var up: Linear

-    public init(dimensions: Int, hiddenDimensions: Int) {
-        self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
-        self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
-        self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
+    init(_ args: LlamaConfiguration) {
+        self._gate.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
+        self._down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: args.mlpBias)
+        self._up.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
     }

-    public func callAsFunction(_ x: MLXArray) -> MLXArray {
-        down(silu(gate(x)) * up(x))
+    func callAsFunction(_ x: MLXArray) -> MLXArray {
+        let activation = silu(gate(x))
+        return down(activation * up(x))
     }
 }

 private class TransformerBlock: Module {

     @ModuleInfo(key: "self_attn") var attention: Attention
-    let mlp: MLP
+    @ModuleInfo(key: "mlp") var mlp: MLP
     @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
     @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm

-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
         self._attention.wrappedValue = Attention(args)
-        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
+        self._mlp.wrappedValue = MLP(args)
         self._inputLayerNorm.wrappedValue = RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
         self._postAttentionLayerNorm.wrappedValue = RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }

-    public func callAsFunction(
+    func callAsFunction(
         _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
     ) -> (MLXArray, (MLXArray, MLXArray)) {
         var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
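
The reworked `MLP` is a standard SwiGLU block: `silu(gate(x))` gates `up(x)` element-wise before the down projection. A minimal sketch with made-up sizes (hidden size 8, intermediate size 16), using the same `Linear` and `silu` APIs the diff already relies on, plus MLX's `ones` constructor:

```swift
import MLX
import MLXNN

// Made-up dimensions standing in for args.hiddenSize / args.intermediateSize.
let gate = Linear(8, 16, bias: false)
let up = Linear(8, 16, bias: false)
let down = Linear(16, 8, bias: false)

let x = ones([1, 4, 8])  // (batch, sequence, hidden size)
// silu(gate(x)) acts as a learned gate on the up projection.
let y = down(silu(gate(x)) * up(x))
print(y.shape)  // [1, 4, 8]
```
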
@@ -132,27 +209,24 @@ private class TransformerBlock: Module {
     }
 }

-public class LlamaModelInner: Module {
+private class LlamaModelInner: Module {

     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

-    fileprivate let layers: [TransformerBlock]
+    let layers: [TransformerBlock]
     let norm: RMSNorm

-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
         precondition(args.vocabularySize > 0)

         self._embedTokens.wrappedValue = Embedding(
             embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)

-        self.layers = (0 ..< args.hiddenLayers)
-            .map { _ in
-                TransformerBlock(args)
-            }
+        self.layers = (0 ..< args.hiddenLayers).map { _ in TransformerBlock(args) }
         self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }

-    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
+    func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
         MLXArray, [(MLXArray, MLXArray)]
     ) {
         var h = embedTokens(inputs)
@@ -178,7 +252,7 @@ public class LlamaModelInner: Module {
 public class LlamaModel: Module, LLMModel {

     public let vocabularySize: Int
-    let model: LlamaModelInner
+    fileprivate let model: LlamaModelInner

     @ModuleInfo(key: "lm_head") var lmHead: Linear?
@@ -202,7 +276,7 @@ public class LlamaModel: Module, LLMModel {
     }

     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
-        // Remove unused precomputed rotary freqs
+        // Remove unused precomputed rotary frequencies
         weights.filter {
             !$0.key.contains("self_attn.rotary_emb.inv_freq")
         }
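
`sanitize` drops the precomputed `rotary_emb.inv_freq` tables some checkpoints carry, since the rotation frequencies are now derived at run time. A toy illustration of the same filter, with strings standing in for real `MLXArray` weights:

```swift
// Hypothetical weight keys; the values are placeholders, not real tensors.
let weights = [
    "model.layers.0.self_attn.q_proj.weight": "tensor",
    "model.layers.0.self_attn.rotary_emb.inv_freq": "tensor",
]
let sanitized = weights.filter { !$0.key.contains("self_attn.rotary_emb.inv_freq") }
print(sanitized.keys)  // only q_proj.weight survives
```
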
@@ -215,14 +289,17 @@ public struct LlamaConfiguration: Codable {
     var hiddenLayers: Int
     var intermediateSize: Int
     var attentionHeads: Int
-    var headDimensions: Int? = nil
+    var headDimensions: Int?
     var rmsNormEps: Float
     var vocabularySize: Int
     var kvHeads: Int
+    var maxPositionEmbeddings: Int?
     var ropeTheta: Float = 10_000
     var ropeTraditional: Bool = false
-    var ropeScaling: [String: StringOrNumber]? = nil
-    var tieWordEmbeddings: Bool = false
+    var ropeScaling: [String: StringOrNumber]?
+    var tieWordEmbeddings: Bool = true
+    var attentionBias: Bool = false
+    var mlpBias: Bool = false

     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -233,45 +310,75 @@ public struct LlamaConfiguration: Codable {
         case rmsNormEps = "rms_norm_eps"
         case vocabularySize = "vocab_size"
         case kvHeads = "num_key_value_heads"
+        case maxPositionEmbeddings = "max_position_embeddings"
         case ropeTheta = "rope_theta"
         case ropeTraditional = "rope_traditional"
         case ropeScaling = "rope_scaling"
         case tieWordEmbeddings = "tie_word_embeddings"
+        case attentionBias = "attention_bias"
+        case mlpBias = "mlp_bias"
     }

     public init(from decoder: Decoder) throws {
         // custom implementation to handle optional keys with required values
-        let container: KeyedDecodingContainer<LlamaConfiguration.CodingKeys> =
-            try decoder.container(
-                keyedBy: LlamaConfiguration.CodingKeys.self)
+        let container = try decoder.container(keyedBy: CodingKeys.self)

-        self.hiddenSize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenSize)
-        self.hiddenLayers = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenLayers)
-        self.intermediateSize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
-        self.attentionHeads = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
-        self.headDimensions = try container.decodeIfPresent(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.headDimensions)
-        self.rmsNormEps = try container.decode(
-            Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
-        self.vocabularySize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.vocabularySize)
-        self.kvHeads = try container.decode(Int.self, forKey: LlamaConfiguration.CodingKeys.kvHeads)
-        self.ropeTheta =
-            try container.decodeIfPresent(
-                Float.self, forKey: LlamaConfiguration.CodingKeys.ropeTheta)
-            ?? 10_000
-        self.ropeTraditional =
-            try container.decodeIfPresent(
-                Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false
-        self.ropeScaling = try container.decodeIfPresent(
-            [String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling)
-        self.tieWordEmbeddings =
-            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
+        hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
+        hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
+        intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
+        attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
+        headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions)
+        rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
+        vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
+        kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads
+        maxPositionEmbeddings = try container.decodeIfPresent(
+            Int.self, forKey: .maxPositionEmbeddings)
+        if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) {
+            self.ropeTheta = ropeTheta
+        }
+        if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional)
+        {
+            self.ropeTraditional = ropeTraditional
+        }
+        ropeScaling = try container.decodeIfPresent(
+            [String: StringOrNumber].self, forKey: .ropeScaling)
+        if let tieWordEmbeddings = try container.decodeIfPresent(
+            Bool.self, forKey: .tieWordEmbeddings)
+        {
+            self.tieWordEmbeddings = tieWordEmbeddings
+        }
+        if let attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) {
+            self.attentionBias = attentionBias
+        }
+        if let mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) {
+            self.mlpBias = mlpBias
+        }
+
+        if let ropeScaling {
+            if ropeScaling["factor"] == nil {
+                throw DecodingError.dataCorruptedError(
+                    forKey: .ropeScaling, in: container,
+                    debugDescription: "rope_scaling must contain 'factor'")
+            }
+            if let ropeType = ropeScaling["type"] ?? ropeScaling["rope_type"] {
+                if case .string = ropeType {
+                    let options = [
+                        StringOrNumber.string("linear"), StringOrNumber.string("dynamic"),
+                        StringOrNumber.string("llama3"),
+                    ]
+                    if !options.contains(ropeType) {
+                        throw DecodingError.dataCorruptedError(
+                            forKey: .ropeScaling, in: container,
+                            debugDescription:
+                                "rope_scaling 'type' currently only supports 'linear', 'dynamic', or 'llama3'"
+                        )
+                    }
+                }
+            } else {
+                throw DecodingError.dataCorruptedError(
+                    forKey: .ropeScaling, in: container,
+                    debugDescription: "rope_scaling must contain either 'type' or 'rope_type'")
+            }
+        }
+    }
 }
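
The new decoder fills defaults for missing optional keys and validates `rope_scaling` up front. Below is a hypothetical minimal `config.json` in the shape it expects — the field names follow the Hugging Face convention, the numbers resemble Llama 3.1 8B but are illustrative, and the snippet assumes it runs inside the same module, since the configuration's stored properties are internal:

```swift
import Foundation

let json = """
    {
        "hidden_size": 4096,
        "num_hidden_layers": 32,
        "intermediate_size": 14336,
        "num_attention_heads": 32,
        "rms_norm_eps": 1e-5,
        "vocab_size": 128256,
        "num_key_value_heads": 8,
        "max_position_embeddings": 131072,
        "rope_theta": 500000.0,
        "rope_scaling": {
            "rope_type": "llama3",
            "factor": 8.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_max_position_embeddings": 8192
        }
    }
    """

// Passes validation: rope_scaling carries "factor" and a supported "rope_type".
let config = try! JSONDecoder().decode(LlamaConfiguration.self, from: Data(json.utf8))
print(config.kvHeads)  // 8
```
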


@@ -151,7 +151,7 @@ extension ModelConfiguration {
         defaultPrompt: "Why is the sky blue?"
     )

-    public static let phi34bit = ModelConfiguration(
+    public static let phi3_4bit = ModelConfiguration(
         id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
         defaultPrompt: "what is the gravity on mars and the moon?",
         extraEOSTokens: ["<|end|>"]
@@ -199,9 +199,17 @@ extension ModelConfiguration {
         "\(prompt)"
     }

-    public static let llama38B4bit = ModelConfiguration(
+    public static let llama3_1_8B_4bit = ModelConfiguration(
+        id: "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
+        defaultPrompt: "What is the difference between a fruit and a vegetable?"
+    ) {
+        prompt in
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
+    }
+
+    public static let llama3_8B_4bit = ModelConfiguration(
         id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-        defaultPrompt: "what is the difference between a fruit and a vegetable?"
+        defaultPrompt: "What is the difference between a fruit and a vegetable?"
     ) {
         prompt in
         "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
@@ -220,12 +228,13 @@ extension ModelConfiguration {
         case .idle:
             bootstrapState = .bootstrapping

             register(configurations: [
+                llama3_1_8B_4bit,
+                mistralNeMo4bit,
                 smolLM_135M_4bit,
                 mistral7B4bit,
                 codeLlama13b4bit,
                 phi4bit,
-                phi34bit,
+                phi3_4bit,
                 gemma2bQuantized,
                 gemma_2_9b_it_4bit,
                 qwen205b4bit,