Use fast (#38)
* Update to the latest MLX Swift and use fast norms
* GPU usage -> memory usage
This commit is contained in:
@@ -9,7 +9,6 @@ import MLXNN
|
||||
|
||||
// specialized norm for gemma
|
||||
private class RMSNorm: Module, UnaryLayer {
|
||||
|
||||
let weight: MLXArray
|
||||
let eps: Float
|
||||
|
||||
@@ -19,13 +18,8 @@ private class RMSNorm: Module, UnaryLayer {
|
||||
super.init()
|
||||
}
|
||||
|
||||
/// Scales `x` by the reciprocal square root of the mean of its squared values.
///
/// The mean is taken over the last axis (`axis: -1`) with `keepDims: true`, and
/// `eps` is added to that mean before `rsqrt` is applied.
///
/// - Parameter x: The array to normalize.
/// - Returns: `x` multiplied by `rsqrt(mean(x^2) + eps)`.
func norm(_ x: MLXArray) -> MLXArray {
|
||||
x * rsqrt(x.square().mean(axis: -1, keepDims: true) + eps)
|
||||
}
|
||||
|
||||
/// Applies the layer's normalization to `x`, scaling by `1 + weight`.
///
/// NOTE(review): this listing looks like a rendered diff in which removed and
/// added lines are merged. The `let output` / first `return` pair is presumably
/// the removed implementation and the final `return` its replacement — confirm
/// against the actual file before relying on either path.
///
/// - Parameter x: The array to normalize.
/// - Returns: The normalized array scaled by `1 + weight`.
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||
// Normalize via `norm` on a Float copy of `x`, then cast back to `x`'s dtype.
let output = norm(x.asType(Float.self)).asType(x.dtype)
|
||||
// The stored weight is offset by one before scaling the normalized output.
return (1 + weight) * output
|
||||
// Single call to MLXFast.rmsNorm, passing the same `1 + weight` offset and `eps`.
return MLXFast.rmsNorm(x, weight: 1.0 + self.weight, eps: self.eps)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user