Add Llama 3.1 (#98)
* Update Mistral 7B config
* Add Mistral NeMo
* Update for Llama 3.1
* Align LlamaConfiguration with Python implementation
* Fix model configuration names
* Refine DynamicNTKScalingRoPE
* Compute base only once

---------

Co-authored-by: Awni Hannun <awni@apple.com>
@@ -159,7 +159,7 @@ class LLMEvaluator {
 
     /// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on
     /// more devices
-    let modelConfiguration = ModelConfiguration.phi34bit
+    let modelConfiguration = ModelConfiguration.phi3_4bit
 
     /// parameters controlling the output
    let generateParameters = GenerateParameters(temperature: 0.6)
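Aside: once the Llama 3.1 configuration registered later in this commit is available, trying it from the evaluator is the same one-line swap:

    // e.g. point the evaluator at the new model added further down in this diff
    let modelConfiguration = ModelConfiguration.llama3_1_8B_4bit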
@@ -7,6 +7,86 @@ import MLXNN
 
 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py
 
+func computeBaseFrequency(
+    base: Float, dims: Int, ropeType: String, ropeScaling: [String: StringOrNumber]?
+)
+    -> Float
+{
+    if ropeType != "llama3" {
+        return base
+    }
+
+    guard let ropeScaling = ropeScaling else {
+        return base
+    }
+
+    guard case .float(let factor) = ropeScaling["factor"],
+        case .float(let lowFreqFactor) = ropeScaling["low_freq_factor"] ?? .float(1.0),
+        case .float(let highFreqFactor) = ropeScaling["high_freq_factor"] ?? .float(4.0),
+        case .float(let oldContextLen) = ropeScaling["original_max_position_embeddings"]
+            ?? .float(8192)
+    else {
+        return base
+    }
+
+    let lowFreqWavelen = oldContextLen / lowFreqFactor
+    let highFreqWavelen = oldContextLen / highFreqFactor
+
+    let freqs = (0 ..< dims).compactMap { index -> Float? in
+        if index % 2 == 0 {
+            return pow(base, Float(index) / Float(dims))
+        }
+        return nil
+    }
+
+    let newBaseFreqs = freqs.map { freq -> Float in
+        let wavelen = 2 * .pi / freq
+        let smooth = max(
+            0, min(1, (wavelen - highFreqWavelen) / (lowFreqWavelen - highFreqWavelen)))
+        return freq * ((1 - smooth) * factor + smooth)
+    }
+
+    return newBaseFreqs.reduce(0, +) / Float(newBaseFreqs.count)
+}
+
+private class DynamicNTKScalingRoPE: Module {
+    let dims: Int
+    let maxPositionEmbeddings: Int?
+    let traditional: Bool
+    let base: Float
+    var scale: Float
+    let ropeType: String
+    let ropeScaling: [String: StringOrNumber]?
+
+    init(
+        dims: Int, maxPositionEmbeddings: Int?, traditional: Bool = false,
+        base: Float = 10000, scale: Float = 1.0, ropeType: String = "default",
+        ropeScaling: [String: StringOrNumber]? = nil
+    ) {
+        self.dims = dims
+        self.maxPositionEmbeddings = maxPositionEmbeddings
+        self.traditional = traditional
+        self.base = computeBaseFrequency(
+            base: base, dims: dims, ropeType: ropeType, ropeScaling: ropeScaling)
+        self.scale = scale
+        self.ropeType = ropeType
+        self.ropeScaling = ropeScaling
+    }
+
+    func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
+        let seqLen = x.dim(1) + offset
+        var base = self.base
+        if let maxPositionEmbeddings, seqLen > maxPositionEmbeddings {
+            let factorAdjustment = Float(seqLen) / Float(maxPositionEmbeddings) - 1
+            let dimensionRatio = Float(dims) / Float(Float(dims) - 2)
+            let adjustedScale = scale * pow(1 + factorAdjustment, dimensionRatio)
+            base *= adjustedScale
+        }
+        return MLXFast.RoPE(
+            x, dimensions: dims, traditional: traditional, base: base, scale: scale, offset: offset)
+    }
+}
+
 private class Attention: Module {
 
     let args: LlamaConfiguration
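For reference, a minimal sketch of exercising the new llama3 scaling path. The rope_scaling values below follow Llama 3.1's published configuration but are illustrative assumptions here, not part of this diff:

    let newBase = computeBaseFrequency(
        base: 500_000, dims: 128, ropeType: "llama3",
        ropeScaling: [
            "factor": .float(8.0),
            "low_freq_factor": .float(1.0),
            "high_freq_factor": .float(4.0),
            "original_max_position_embeddings": .float(8192),
        ])
    // Components whose wavelength exceeds the original context window are scaled
    // by `factor`, high-frequency components are kept as-is, and the band between
    // the two cutoffs is blended via `smooth`; the per-component results are then
    // averaged into a single adjusted base. Any other ropeType returns `base`.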
@@ -17,9 +97,9 @@ private class Attention: Module {
|
|||||||
@ModuleInfo(key: "v_proj") var wv: Linear
|
@ModuleInfo(key: "v_proj") var wv: Linear
|
||||||
@ModuleInfo(key: "o_proj") var wo: Linear
|
@ModuleInfo(key: "o_proj") var wo: Linear
|
||||||
|
|
||||||
let rope: RoPE
|
let rope: DynamicNTKScalingRoPE
|
||||||
|
|
||||||
public init(_ args: LlamaConfiguration) {
|
init(_ args: LlamaConfiguration) {
|
||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
let dim = args.hiddenSize
|
let dim = args.hiddenSize
|
||||||
@@ -29,31 +109,28 @@ private class Attention: Module {
|
|||||||
let headDim = args.headDimensions ?? (args.hiddenSize / heads)
|
let headDim = args.headDimensions ?? (args.hiddenSize / heads)
|
||||||
self.scale = pow(Float(headDim), -0.5)
|
self.scale = pow(Float(headDim), -0.5)
|
||||||
|
|
||||||
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
|
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: args.attentionBias)
|
||||||
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
|
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
|
||||||
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
|
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
|
||||||
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
|
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: args.attentionBias)
|
||||||
|
|
||||||
let ropeScale: Float
|
self.rope = DynamicNTKScalingRoPE(
|
||||||
if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
|
dims: headDim,
|
||||||
let factor = ropeScaling["factor"]
|
maxPositionEmbeddings: args.maxPositionEmbeddings,
|
||||||
{
|
traditional: args.ropeTraditional,
|
||||||
switch factor {
|
base: args.ropeTheta,
|
||||||
case .string:
|
scale: 1.0,
|
||||||
fatalError("ropeScaling.factor must be a float")
|
ropeType: {
|
||||||
case .float(let v):
|
if case .string(let value) = args.ropeScaling?["type"] {
|
||||||
ropeScale = 1 / v
|
return value
|
||||||
}
|
} else {
|
||||||
} else {
|
return "default"
|
||||||
ropeScale = 1
|
}
|
||||||
}
|
}(),
|
||||||
|
ropeScaling: args.ropeScaling)
|
||||||
self.rope = RoPE(
|
|
||||||
dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
|
|
||||||
scale: ropeScale)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public func callAsFunction(
|
func callAsFunction(
|
||||||
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||||
) -> (MLXArray, (MLXArray, MLXArray)) {
|
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||||
let (B, L) = (x.dim(0), x.dim(1))
|
let (B, L) = (x.dim(0), x.dim(1))
|
||||||
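As a worked example of the dynamic NTK branch in `callAsFunction` above (all numbers are illustrative assumptions):

    import Foundation

    let (dims, scale) = (128, Float(1))
    let (seqLen, maxPositionEmbeddings) = (2048, 1024)

    let factorAdjustment = Float(seqLen) / Float(maxPositionEmbeddings) - 1  // 1.0
    let dimensionRatio = Float(dims) / Float(Float(dims) - 2)                // ~1.016
    let adjustedScale = scale * pow(1 + factorAdjustment, dimensionRatio)    // 2^1.016 ~ 2.02
    print(adjustedScale)
    // base *= adjustedScale: a base of 500_000 would grow to roughly 1.01e6 for
    // this window, stretching RoPE periods once the sequence exceeds the trained
    // context length.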
@@ -62,7 +139,7 @@ private class Attention: Module {
         var keys = wk(x)
         var values = wv(x)
 
-        // prepare the queries, keys and values for the attention computation
+        // Prepare the queries, keys and values for the attention computation
         queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
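The reshape/transpose pair lays the tensors out as (batch, heads, length, headDim), with fewer key/value heads than query heads (grouped-query attention). A standalone sketch of the shape flow, assuming MLX's `zeros` helper and illustrative sizes:

    import MLX

    let (B, L) = (1, 8)
    let (heads, kvHeads, headDim) = (32, 8, 128)

    var queries = zeros([B, L, heads * headDim])
    var keys = zeros([B, L, kvHeads * headDim])

    // (B, L, heads * headDim) -> (B, heads, L, headDim)
    queries = queries.reshaped(B, L, heads, -1).transposed(0, 2, 1, 3)
    keys = keys.reshaped(B, L, kvHeads, -1).transposed(0, 2, 1, 3)

    print(queries.shape)  // [1, 32, 8, 128]
    print(keys.shape)     // [1, 8, 8, 128]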
@@ -93,35 +170,35 @@ private class MLP: Module, UnaryLayer {
     @ModuleInfo(key: "down_proj") var down: Linear
     @ModuleInfo(key: "up_proj") var up: Linear
 
-    public init(dimensions: Int, hiddenDimensions: Int) {
-        self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
-        self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
-        self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
+    init(_ args: LlamaConfiguration) {
+        self._gate.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
+        self._down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: args.mlpBias)
+        self._up.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
     }
 
-    public func callAsFunction(_ x: MLXArray) -> MLXArray {
-        down(silu(gate(x)) * up(x))
+    func callAsFunction(_ x: MLXArray) -> MLXArray {
+        let activation = silu(gate(x))
+        return down(activation * up(x))
     }
 }
 
 private class TransformerBlock: Module {
 
     @ModuleInfo(key: "self_attn") var attention: Attention
-    let mlp: MLP
+    @ModuleInfo(key: "mlp") var mlp: MLP
 
     @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
     @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
 
-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
         self._attention.wrappedValue = Attention(args)
-        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
+        self._mlp.wrappedValue = MLP(args)
         self._inputLayerNorm.wrappedValue = RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
         self._postAttentionLayerNorm.wrappedValue = RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
-    public func callAsFunction(
+    func callAsFunction(
         _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
     ) -> (MLXArray, (MLXArray, MLXArray)) {
         var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
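The MLP body is the standard SwiGLU feed-forward; only the `bias:` arguments are new, driven by `mlpBias` from the configuration. A self-contained sketch with hypothetical sizes:

    import MLX
    import MLXNN

    let (dim, hidden) = (16, 64)
    let gate = Linear(dim, hidden, bias: false)
    let up = Linear(dim, hidden, bias: false)
    let down = Linear(hidden, dim, bias: false)

    let x = zeros([1, 4, dim])
    let y = down(silu(gate(x)) * up(x))  // shape stays [1, 4, dim]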
@@ -132,27 +209,24 @@ private class TransformerBlock: Module {
     }
 }
 
-public class LlamaModelInner: Module {
+private class LlamaModelInner: Module {
 
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
 
-    fileprivate let layers: [TransformerBlock]
+    let layers: [TransformerBlock]
     let norm: RMSNorm
 
-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
         precondition(args.vocabularySize > 0)
 
         self._embedTokens.wrappedValue = Embedding(
             embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
 
-        self.layers = (0 ..< args.hiddenLayers)
-            .map { _ in
-                TransformerBlock(args)
-            }
+        self.layers = (0 ..< args.hiddenLayers).map { _ in TransformerBlock(args) }
         self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
-    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
+    func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
         MLXArray, [(MLXArray, MLXArray)]
     ) {
         var h = embedTokens(inputs)
@@ -178,7 +252,7 @@ public class LlamaModelInner: Module {
 public class LlamaModel: Module, LLMModel {
 
     public let vocabularySize: Int
-    let model: LlamaModelInner
+    fileprivate let model: LlamaModelInner
 
     @ModuleInfo(key: "lm_head") var lmHead: Linear?
 
@@ -202,7 +276,7 @@ public class LlamaModel: Module, LLMModel {
     }
 
     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
-        // Remove unused precomputed rotary freqs
+        // Remove unused precomputed rotary frequencies
         weights.filter {
             !$0.key.contains("self_attn.rotary_emb.inv_freq")
         }
@@ -215,14 +289,17 @@ public struct LlamaConfiguration: Codable {
     var hiddenLayers: Int
     var intermediateSize: Int
     var attentionHeads: Int
-    var headDimensions: Int? = nil
+    var headDimensions: Int?
     var rmsNormEps: Float
     var vocabularySize: Int
     var kvHeads: Int
+    var maxPositionEmbeddings: Int?
     var ropeTheta: Float = 10_000
     var ropeTraditional: Bool = false
-    var ropeScaling: [String: StringOrNumber]? = nil
-    var tieWordEmbeddings: Bool = false
+    var ropeScaling: [String: StringOrNumber]?
+    var tieWordEmbeddings: Bool = true
+    var attentionBias: Bool = false
+    var mlpBias: Bool = false
 
     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -233,45 +310,75 @@ public struct LlamaConfiguration: Codable {
         case rmsNormEps = "rms_norm_eps"
         case vocabularySize = "vocab_size"
         case kvHeads = "num_key_value_heads"
+        case maxPositionEmbeddings = "max_position_embeddings"
         case ropeTheta = "rope_theta"
         case ropeTraditional = "rope_traditional"
         case ropeScaling = "rope_scaling"
         case tieWordEmbeddings = "tie_word_embeddings"
+        case attentionBias = "attention_bias"
+        case mlpBias = "mlp_bias"
     }
 
     public init(from decoder: Decoder) throws {
-        // custom implementation to handle optional keys with required values
-        let container: KeyedDecodingContainer<LlamaConfiguration.CodingKeys> =
-            try decoder.container(
-                keyedBy: LlamaConfiguration.CodingKeys.self)
+        let container = try decoder.container(keyedBy: CodingKeys.self)
 
-        self.hiddenSize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenSize)
-        self.hiddenLayers = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenLayers)
-        self.intermediateSize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
-        self.attentionHeads = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
-        self.headDimensions = try container.decodeIfPresent(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.headDimensions)
-        self.rmsNormEps = try container.decode(
-            Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
-        self.vocabularySize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.vocabularySize)
-        self.kvHeads = try container.decode(Int.self, forKey: LlamaConfiguration.CodingKeys.kvHeads)
-        self.ropeTheta =
-            try container.decodeIfPresent(
-                Float.self, forKey: LlamaConfiguration.CodingKeys.ropeTheta)
-            ?? 10_000
-        self.ropeTraditional =
-            try container.decodeIfPresent(
-                Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false
-        self.ropeScaling = try container.decodeIfPresent(
-            [String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling)
-        self.tieWordEmbeddings =
-            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
+        hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
+        hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
+        intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
+        attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
+        headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions)
+        rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
+        vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
+        kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads
+        maxPositionEmbeddings = try container.decodeIfPresent(
+            Int.self, forKey: .maxPositionEmbeddings)
+        if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) {
+            self.ropeTheta = ropeTheta
+        }
+        if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional)
+        {
+            self.ropeTraditional = ropeTraditional
+        }
+        ropeScaling = try container.decodeIfPresent(
+            [String: StringOrNumber].self, forKey: .ropeScaling)
+        if let tieWordEmbeddings = try container.decodeIfPresent(
+            Bool.self, forKey: .tieWordEmbeddings)
+        {
+            self.tieWordEmbeddings = tieWordEmbeddings
+        }
+        if let attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) {
+            self.attentionBias = attentionBias
+        }
+        if let mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) {
+            self.mlpBias = mlpBias
+        }
+
+        if let ropeScaling {
+            if ropeScaling["factor"] == nil {
+                throw DecodingError.dataCorruptedError(
+                    forKey: .ropeScaling, in: container,
+                    debugDescription: "rope_scaling must contain 'factor'")
+            }
+            if let ropeType = ropeScaling["type"] ?? ropeScaling["rope_type"] {
+                if case .string = ropeType {
+                    let options = [
+                        StringOrNumber.string("linear"), StringOrNumber.string("dynamic"),
+                        StringOrNumber.string("llama3"),
+                    ]
+                    if !options.contains(ropeType) {
+                        throw DecodingError.dataCorruptedError(
+                            forKey: .ropeScaling, in: container,
+                            debugDescription:
+                                "rope_scaling 'type' currently only supports 'linear', 'dynamic', or 'llama3'"
+                        )
+                    }
+                }
+            } else {
+                throw DecodingError.dataCorruptedError(
+                    forKey: .ropeScaling, in: container,
+                    debugDescription: "rope_scaling must contain either 'type' or 'rope_type'")
+            }
+        }
     }
 }
 
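A minimal sketch of decoding the new fields. The JSON keys mirror Hugging Face's config.json naming; the values below follow Llama 3.1 8B's published configuration but are illustrative assumptions, not part of this diff:

    import Foundation

    let json = """
        {
            "hidden_size": 4096, "num_hidden_layers": 32, "intermediate_size": 14336,
            "num_attention_heads": 32, "rms_norm_eps": 1e-5, "vocab_size": 128256,
            "num_key_value_heads": 8, "max_position_embeddings": 131072,
            "rope_theta": 500000.0,
            "rope_scaling": {
                "factor": 8.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0,
                "original_max_position_embeddings": 8192, "rope_type": "llama3"
            }
        }
        """

    do {
        let config = try JSONDecoder().decode(LlamaConfiguration.self, from: Data(json.utf8))
        // tieWordEmbeddings keeps its new default (true) since the key is absent;
        // rope_scaling passes validation: it has "factor" and a supported "rope_type".
        print(config.maxPositionEmbeddings ?? 0, config.ropeTheta)
    } catch {
        print("decoding failed:", error)
    }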
@@ -151,7 +151,7 @@ extension ModelConfiguration {
         defaultPrompt: "Why is the sky blue?"
     )
 
-    public static let phi34bit = ModelConfiguration(
+    public static let phi3_4bit = ModelConfiguration(
         id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
         defaultPrompt: "what is the gravity on mars and the moon?",
         extraEOSTokens: ["<|end|>"]
@@ -199,9 +199,17 @@ extension ModelConfiguration {
         "\(prompt)"
     }
 
-    public static let llama38B4bit = ModelConfiguration(
+    public static let llama3_1_8B_4bit = ModelConfiguration(
+        id: "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
+        defaultPrompt: "What is the difference between a fruit and a vegetable?"
+    ) {
+        prompt in
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
+    }
+
+    public static let llama3_8B_4bit = ModelConfiguration(
         id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-        defaultPrompt: "what is the difference between a fruit and a vegetable?"
+        defaultPrompt: "What is the difference between a fruit and a vegetable?"
     ) {
         prompt in
         "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
@@ -220,12 +228,13 @@ extension ModelConfiguration {
        case .idle:
            bootstrapState = .bootstrapping
            register(configurations: [
+                llama3_1_8B_4bit,
                mistralNeMo4bit,
                smolLM_135M_4bit,
                mistral7B4bit,
                codeLlama13b4bit,
                phi4bit,
-                phi34bit,
+                phi3_4bit,
                gemma2bQuantized,
                gemma_2_9b_it_4bit,
                qwen205b4bit,