Add Mistral NeMo (#97)

* Update Mistral 7B config

* Add Mistral NeMo
Anthony DePasquale authored 2024-07-25 17:01:26 +02:00, committed by GitHub
parent a2e8d7e469
commit c4fda0e036
2 changed files with 19 additions and 6 deletions

@@ -26,7 +26,7 @@ private class Attention: Module {
let heads = args.attentionHeads
let kvHeads = args.kvHeads
-let headDim = args.hiddenSize / heads
+let headDim = args.headDimensions ?? (args.hiddenSize / heads)
self.scale = pow(Float(headDim), -0.5)
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
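The change above prefers an explicit head_dim from the config and only falls back to hiddenSize / heads when none is given. A minimal sketch of that expression follows; the Mistral 7B-like and Mistral NeMo-like numbers are illustrative assumptions, not values taken from this diff.

```swift
// Sketch of the headDim fallback. The struct and the concrete numbers
// (7B-like 4096/32 without head_dim, NeMo-like 5120/32 with head_dim 128)
// are assumptions for illustration only.
struct AttentionArgs {
    var hiddenSize: Int
    var attentionHeads: Int
    var headDimensions: Int?
}

let mistral7BLike = AttentionArgs(hiddenSize: 4096, attentionHeads: 32, headDimensions: nil)
let mistralNeMoLike = AttentionArgs(hiddenSize: 5120, attentionHeads: 32, headDimensions: 128)

for args in [mistral7BLike, mistralNeMoLike] {
    // Same expression as the patched initializer: prefer the explicit head_dim,
    // otherwise derive it from the hidden size.
    let headDim = args.headDimensions ?? (args.hiddenSize / args.attentionHeads)
    // Equivalent to pow(Float(headDim), -0.5) in the diff.
    let scale = 1 / Float(headDim).squareRoot()
    print("headDim: \(headDim), scale: \(scale)")
}
// Both print headDim 128 here, but only because the NeMo-like args set head_dim
// explicitly; deriving 5120 / 32 would give 160 instead.
```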
@@ -215,6 +215,7 @@ public struct LlamaConfiguration: Codable {
var hiddenLayers: Int
var intermediateSize: Int
var attentionHeads: Int
+var headDimensions: Int? = nil
var rmsNormEps: Float
var vocabularySize: Int
var kvHeads: Int
@@ -228,6 +229,7 @@ public struct LlamaConfiguration: Codable {
case hiddenLayers = "num_hidden_layers"
case intermediateSize = "intermediate_size"
case attentionHeads = "num_attention_heads"
+case headDimensions = "head_dim"
case rmsNormEps = "rms_norm_eps"
case vocabularySize = "vocab_size"
case kvHeads = "num_key_value_heads"
@@ -251,6 +253,8 @@ public struct LlamaConfiguration: Codable {
Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
self.attentionHeads = try container.decode(
Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
+self.headDimensions = try container.decodeIfPresent(
+Int.self, forKey: LlamaConfiguration.CodingKeys.headDimensions)
self.rmsNormEps = try container.decode(
Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
self.vocabularySize = try container.decode(
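
The configuration changes make head_dim an optional field: decodeIfPresent returns nil when the key is missing, so configs that omit head_dim still decode, while configs that set it are honored. Below is a small self-contained sketch of that behaviour, with a hypothetical MiniConfig standing in for LlamaConfiguration and made-up JSON values.

```swift
import Foundation

// Hypothetical trimmed-down stand-in for LlamaConfiguration, mirroring the
// decodeIfPresent change in this commit. The JSON below is made up for illustration.
struct MiniConfig: Codable {
    var hiddenSize: Int
    var attentionHeads: Int
    var headDimensions: Int?

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case attentionHeads = "num_attention_heads"
        case headDimensions = "head_dim"
    }

    init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
        self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
        // decodeIfPresent returns nil instead of throwing when "head_dim" is absent.
        self.headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions)
    }
}

let withHeadDim = #"{"hidden_size": 5120, "num_attention_heads": 32, "head_dim": 128}"#
let withoutHeadDim = #"{"hidden_size": 4096, "num_attention_heads": 32}"#

for json in [withHeadDim, withoutHeadDim] {
    let config = try! JSONDecoder().decode(MiniConfig.self, from: Data(json.utf8))
    // Falls back to hiddenSize / heads exactly like the Attention layer does.
    let headDim = config.headDimensions ?? (config.hiddenSize / config.attentionHeads)
    print(headDim)
}
```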