Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00.
Change layernorms to two pass algorithm (#2246)
This commit is contained in:
committed by
GitHub
parent
24f89173d1
commit
2e8cf0b450
@@ -231,13 +231,11 @@ array layer_norm(
|
||||
const std::vector<array>& inputs) {
|
||||
auto x = astype(inputs[0], float32, s);
|
||||
|
||||
// Should I not be smart here and leave the double mean to simplify()?
|
||||
auto mu = mean(x, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
auto mu2 = square(mu, s);
|
||||
auto x2 = mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
|
||||
auto v = subtract(x2, mu2, s);
|
||||
auto xc = subtract(x, mu, s);
|
||||
auto v = mean(square(xc, s), /* axis= */ -1, /* keepdims= */ true, s);
|
||||
|
||||
x = multiply(subtract(x, mu, s), rsqrt(add(v, array(eps, float32), s), s));
|
||||
x = multiply(xc, rsqrt(add(v, array(eps, float32), s), s));
|
||||
x = astype(x, out_type, s);
|
||||
|
||||
// If the LN is affine then transform x according to the weight and bias
|
||||
|
||||
Reference in New Issue
Block a user