Use fast (#38)

* update to latest mlx swift and use fast norms
* gpu usage -> memory usage
This commit is contained in:
Awni Hannun
2024-03-27 16:37:35 -07:00
committed by GitHub
parent 3314e20a24
commit 15b38cd146
4 changed files with 5 additions and 17 deletions

View File

@@ -9,7 +9,6 @@ import MLXNN
// specialized norm for gemma
private class RMSNorm: Module, UnaryLayer {
let weight: MLXArray
let eps: Float
@@ -19,13 +18,8 @@ private class RMSNorm: Module, UnaryLayer {
super.init()
}
func norm(_ x: MLXArray) -> MLXArray {
x * rsqrt(x.square().mean(axis: -1, keepDims: true) + eps)
}
public func callAsFunction(_ x: MLXArray) -> MLXArray {
let output = norm(x.asType(Float.self)).asType(x.dtype)
return (1 + weight) * output
return MLXFast.rmsNorm(x, weight: 1.0 + self.weight, eps: self.eps)
}
}