fix rmsnorm for gemma
This commit is contained in:
@@ -19,10 +19,7 @@ private class RMSNorm: Module, UnaryLayer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func norm(_ x: MLXArray) -> MLXArray {
|
func norm(_ x: MLXArray) -> MLXArray {
|
||||||
let S = 1.0 / sqrt(Float(x.dim(-1)))
|
x * rsqrt(x.square().mean(axis: -1, keepDims: true) + eps)
|
||||||
|
|
||||||
let n = (x * S).square().sum(axis: -1, keepDims: true)
|
|
||||||
return rsqrt(n + eps)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
|
|||||||
Reference in New Issue
Block a user