commit c4fda0e036 (parent a2e8d7e469), committed by GitHub
@@ -26,7 +26,7 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads

-        let headDim = args.hiddenSize / heads
+        let headDim = args.headDimensions ?? (args.hiddenSize / heads)
         self.scale = pow(Float(headDim), -0.5)

         self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
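
This hunk makes the attention head dimension configurable: an explicit `headDimensions` from the checkpoint config wins, and models that do not set it keep the old `hiddenSize / heads` derivation. Below is a minimal, self-contained sketch of that fallback and the resulting attention scale; the `Args` type and its values are hypothetical stand-ins for the real config.

    import Foundation

    // Hypothetical stand-in for the model arguments.
    struct Args {
        var hiddenSize: Int
        var attentionHeads: Int
        var headDimensions: Int?  // nil when the config carries no "head_dim"
    }

    let args = Args(hiddenSize: 2048, attentionHeads: 32, headDimensions: nil)

    // Prefer the explicit head_dim; otherwise derive it from the hidden size.
    let headDim = args.headDimensions ?? (args.hiddenSize / args.attentionHeads)
    let scale = pow(Float(headDim), -0.5)  // 1 / sqrt(headDim) attention scaling

    print(headDim, scale)  // 64 0.125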
@@ -215,6 +215,7 @@ public struct LlamaConfiguration: Codable {
     var hiddenLayers: Int
     var intermediateSize: Int
     var attentionHeads: Int
+    var headDimensions: Int? = nil
     var rmsNormEps: Float
     var vocabularySize: Int
     var kvHeads: Int
@@ -228,6 +229,7 @@ public struct LlamaConfiguration: Codable {
         case hiddenLayers = "num_hidden_layers"
         case intermediateSize = "intermediate_size"
         case attentionHeads = "num_attention_heads"
+        case headDimensions = "head_dim"
         case rmsNormEps = "rms_norm_eps"
         case vocabularySize = "vocab_size"
         case kvHeads = "num_key_value_heads"
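
The new coding key maps the `headDimensions` property onto the `head_dim` field found in some model config.json files. A hedged sketch of that mapping, using a pared-down, hypothetical config type and Foundation's `JSONDecoder` (synthesized `Codable` conformance decodes the optional automatically):

    import Foundation

    // Hypothetical pared-down config; only the key mapping is illustrated.
    struct MiniConfig: Codable {
        var attentionHeads: Int
        var headDimensions: Int?  // optional: older configs omit "head_dim"

        enum CodingKeys: String, CodingKey {
            case attentionHeads = "num_attention_heads"
            case headDimensions = "head_dim"
        }
    }

    let json = #"{"num_attention_heads": 32, "head_dim": 128}"#
    let config = try! JSONDecoder().decode(MiniConfig.self, from: Data(json.utf8))
    print(config.headDimensions ?? 0)  // 128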
@@ -251,6 +253,8 @@ public struct LlamaConfiguration: Codable {
             Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
         self.attentionHeads = try container.decode(
             Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
+        self.headDimensions = try container.decodeIfPresent(
+            Int.self, forKey: LlamaConfiguration.CodingKeys.headDimensions)
         self.rmsNormEps = try container.decode(
             Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
         self.vocabularySize = try container.decode(
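
`decodeIfPresent` is what keeps older checkpoints working: when `head_dim` is missing from the config, the property decodes to `nil` instead of throwing, and the attention layer falls back to `hiddenSize / heads`. A self-contained sketch of that path, again with a hypothetical mini config whose hand-written decoder mirrors the decode / decodeIfPresent split in the diff above:

    import Foundation

    // Hypothetical mini config with a hand-written decoder.
    struct MiniConfig: Codable {
        var attentionHeads: Int
        var headDimensions: Int?

        enum CodingKeys: String, CodingKey {
            case attentionHeads = "num_attention_heads"
            case headDimensions = "head_dim"
        }

        init(from decoder: Decoder) throws {
            let container = try decoder.container(keyedBy: CodingKeys.self)
            // Required key: decode(...) throws if it is absent.
            self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
            // Optional key: decodeIfPresent(...) returns nil when "head_dim" is absent.
            self.headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions)
        }
    }

    // A legacy config without "head_dim" still decodes; the property is nil.
    let legacy = #"{"num_attention_heads": 32}"#
    let config = try! JSONDecoder().decode(MiniConfig.self, from: Data(legacy.utf8))
    print(config.headDimensions as Any)  // nil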