fix float types in Phi (use float16) (#25)
- Per the suggestions in #23, ensure that the values that go into the cache are float16.
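The motivation is memory: each cache entry at float32 is twice the size of the same entry at float16. A quick illustrative check in MLX Swift (the cache-entry shape below is hypothetical, not taken from the repo):

    import MLX

    // Hypothetical cache-entry shape: [batch, heads, sequence, headDim].
    let entry16 = zeros([1, 32, 1024, 80]).asType(.float16)
    let entry32 = entry16.asType(.float32)
    print(entry16.nbytes, entry32.nbytes)  // 5242880 vs 10485760 bytes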
@@ -73,14 +73,12 @@ private class PhiAttention: Module {
             keys = rope(keys)
         }
 
-        queries = queries.asType(Float.self)
-        keys = keys.asType(Float.self)
-
         // Finally perform the attention computation
         let scale = sqrt(1 / Float(queries.dim(-1)))
         let output = MLXFast.scaledDotProductAttention(
-            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+            queries: queries.asType(.float32), keys: keys, values: values, scale: scale, mask: mask
         )
+        .asType(values.dtype)
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)
 
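For context, here is a minimal standalone sketch of the resulting pattern, assuming hypothetical shapes and random stand-in activations (B, heads, L, headDim and the MLXRandom inputs below are illustrative, not from the repo): keys and values stay float16, the queries are upcast to float32 for the attention computation, and the output is cast back to the value dtype before the transpose/reshape.

    import Foundation
    import MLX
    import MLXFast
    import MLXRandom

    // Illustrative shapes: batch, heads, sequence length, head dimension.
    let (B, heads, L, headDim) = (1, 4, 8, 32)

    // Stand-ins for the projected (and rotated) activations; in the real
    // module these come from the linear projections and the KV cache.
    let queries = MLXRandom.normal([B, heads, L, headDim]).asType(.float16)
    let keys = MLXRandom.normal([B, heads, L, headDim]).asType(.float16)
    let values = MLXRandom.normal([B, heads, L, headDim]).asType(.float16)

    let scale = sqrt(1 / Float(queries.dim(-1)))
    let output = MLXFast.scaledDotProductAttention(
        queries: queries.asType(.float32), keys: keys, values: values,
        scale: scale, mask: nil
    )
    .asType(values.dtype)  // back to float16 for everything downstream
    .transposed(0, 2, 1, 3)
    .reshaped(B, L, -1)

    print(output.dtype, output.shape)  // float16, [1, 8, 128]

The upshot of the change: anything that persists across tokens (the KV cache) stays in float16 to save memory, while the attention math itself runs with float32 queries for numerical stability.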