fix float types in Phi (use float16) (#25)
- per suggestions in #23, ensure that the values that go into the cache are float16
@@ -73,14 +73,12 @@ private class PhiAttention: Module {
             keys = rope(keys)
         }
 
-        queries = queries.asType(Float.self)
-        keys = keys.asType(Float.self)
-
         // Finally perform the attention computation
         let scale = sqrt(1 / Float(queries.dim(-1)))
         let output = MLXFast.scaledDotProductAttention(
-            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+            queries: queries.asType(.float32), keys: keys, values: values, scale: scale, mask: mask
         )
+        .asType(values.dtype)
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)
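
For readers skimming the change: the net effect is that queries, keys, and values stay float16 end to end, and only the fused attention kernel is fed a float32 query, with the output cast straight back. Below is a minimal free-standing sketch of that pattern, assuming float16 inputs shaped [B, heads, L, headDim]; the function name phiStyleAttention and the explicit B/L parameters are illustrative and not part of the actual PhiAttention module.

import Foundation
import MLX
import MLXFast

// Sketch of the post-fix dtype flow (not the real PhiAttention implementation).
// Inputs are assumed float16; only the attention kernel computes in float32,
// and the result is cast back so anything cached downstream stays float16.
func phiStyleAttention(
    queries: MLXArray, keys: MLXArray, values: MLXArray,
    mask: MLXArray?, B: Int, L: Int
) -> MLXArray {
    let scale = sqrt(1 / Float(queries.dim(-1)))
    return MLXFast.scaledDotProductAttention(
        queries: queries.asType(.float32),
        keys: keys, values: values, scale: scale, mask: mask
    )
    .asType(values.dtype)     // back to the values' dtype (float16) on the way out
    .transposed(0, 2, 1, 3)   // [B, heads, L, headDim] -> [B, L, heads, headDim]
    .reshaped(B, L, -1)       // merge heads back into the hidden dimension
}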