fix float types in Phi (use float16) (#25)

- per suggestions in #23, ensure that the values that go into the cache are float16
Author: David Koski
Date: 2024-03-14 13:18:40 -07:00
Committed by: GitHub
parent a1431e7155
commit ac273a14ea


@@ -73,14 +73,12 @@ private class PhiAttention: Module {
             keys = rope(keys)
         }
 
-        queries = queries.asType(Float.self)
-        keys = keys.asType(Float.self)
-
         // Finally perform the attention computation
         let scale = sqrt(1 / Float(queries.dim(-1)))
         let output = MLXFast.scaledDotProductAttention(
-            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+            queries: queries.asType(.float32), keys: keys, values: values, scale: scale, mask: mask
         )
         .asType(values.dtype)
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)
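
The removed asType(Float.self) calls upcast queries and keys before the keys were handed back as the updated KV cache, so the cache filled with float32 tensors at twice the intended size. With this change, keys and values stay float16 throughout; only the query is upcast at the call site, and the result is cast back to the value dtype. Below is a minimal, self-contained sketch of the resulting dtype behavior; the shapes and variable names are illustrative rather than taken from the repo, and it assumes scaledDotProductAttention accepts mixed-precision inputs by promoting them to a common dtype.

import Foundation
import MLX
import MLXFast
import MLXRandom

// Illustrative shapes: batch 1, 32 heads, sequence length 8, head dim 80.
let (B, H, L, D) = (1, 32, 8, 80)
let queries = MLXRandom.normal([B, H, L, D]).asType(.float16)
let keys = MLXRandom.normal([B, H, L, D]).asType(.float16)
let values = MLXRandom.normal([B, H, L, D]).asType(.float16)

// Only the query side is upcast; keys and values (and thus whatever
// cache they came from or go into) remain float16.
let scale = sqrt(1 / Float(queries.dim(-1)))
let output = MLXFast.scaledDotProductAttention(
    queries: queries.asType(.float32), keys: keys, values: values,
    scale: scale, mask: nil
)
.asType(values.dtype)

print(keys.dtype, output.dtype)  // float16 float16

Casting the output back to values.dtype keeps the rest of the network in float16, so the attention scores get the extra precision of a float32 computation without growing the cache.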