diff --git a/Libraries/LLM/Phi.swift b/Libraries/LLM/Phi.swift
index 9a68e29..4c9763b 100644
--- a/Libraries/LLM/Phi.swift
+++ b/Libraries/LLM/Phi.swift
@@ -73,14 +73,12 @@ private class PhiAttention: Module {
             keys = rope(keys)
         }
 
-        queries = queries.asType(Float.self)
-        keys = keys.asType(Float.self)
-
         // Finally perform the attention computation
         let scale = sqrt(1 / Float(queries.dim(-1)))
         let output = MLXFast.scaledDotProductAttention(
-            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+            queries: queries.asType(.float32), keys: keys, values: values, scale: scale, mask: mask
         )
+        .asType(values.dtype)
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)
 
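For reference, a minimal sketch of what the attention step looks like after this patch: the standalone `queries = queries.asType(Float.self)` / `keys = keys.asType(Float.self)` conversions are gone, only the queries are cast to `.float32` inline at the call site, and the result is cast back to `values.dtype` before the transpose/reshape. The helper name `phiStyleAttention` and its parameter list are illustrative; in the real code this lives inside `PhiAttention` and `B`/`L` come from the surrounding head bookkeeping.

```swift
import Foundation
import MLX
import MLXFast

// Illustrative sketch only; not the exact PhiAttention method signature.
func phiStyleAttention(
    queries: MLXArray, keys: MLXArray, values: MLXArray, mask: MLXArray?,
    B: Int, L: Int
) -> MLXArray {
    let scale = sqrt(1 / Float(queries.dim(-1)))
    return MLXFast.scaledDotProductAttention(
        // Only the queries are cast up to float32 at the call site.
        queries: queries.asType(.float32), keys: keys, values: values,
        scale: scale, mask: mask
    )
    // Cast the result back to the values' dtype so downstream layers keep
    // the model's working precision (e.g. float16), then restore the
    // (B, L, hiddenSize) layout.
    .asType(values.dtype)
    .transposed(0, 2, 1, 3)
    .reshaped(B, L, -1)
}
```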