adopt MLXFast.scaledDotProductAttention (#23)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import Foundation
|
||||
import MLX
|
||||
import MLXFast
|
||||
import MLXNN
|
||||
|
||||
// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py
|
||||
@@ -9,7 +10,6 @@ import MLXNN
|
||||
private class Attention: Module {
|
||||
|
||||
let args: LlamaConfiguration
|
||||
let repeats: Int
|
||||
let scale: Float
|
||||
|
||||
@ModuleInfo(key: "q_proj") var wq: Linear
|
||||
@@ -26,8 +26,6 @@ private class Attention: Module {
|
||||
let heads = args.attentionHeads
|
||||
let kvHeads = args.kvHeads
|
||||
|
||||
self.repeats = heads / kvHeads
|
||||
|
||||
let headDim = args.hiddenSize / heads
|
||||
self.scale = pow(Float(headDim), -0.5)
|
||||
|
||||
@@ -69,11 +67,6 @@ private class Attention: Module {
|
||||
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
|
||||
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
|
||||
|
||||
if repeats > 1 {
|
||||
keys = MLXArray.repeat(keys, count: repeats, axis: 1)
|
||||
values = MLXArray.repeat(values, count: repeats, axis: 1)
|
||||
}
|
||||
|
||||
if let (keyCache, valueCache) = cache {
|
||||
queries = rope(queries, offset: keyCache.dim(2))
|
||||
keys = rope(keys, offset: keyCache.dim(2))
|
||||
@@ -84,14 +77,11 @@ private class Attention: Module {
|
||||
keys = rope(keys)
|
||||
}
|
||||
|
||||
var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2))
|
||||
if let mask {
|
||||
scores = scores + mask
|
||||
}
|
||||
|
||||
scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
|
||||
|
||||
let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
|
||||
let output = MLXFast.scaledDotProductAttention(
|
||||
queries: queries, keys: keys, values: values, scale: scale, mask: mask
|
||||
)
|
||||
.transposed(0, 2, 1, 3)
|
||||
.reshaped(B, L, -1)
|
||||
|
||||
return (wo(output), (keys, values))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user