Change layernorms to two pass algorithm (#2246)

Angelos Katharopoulos
2025-06-06 13:34:56 -07:00
committed by GitHub
parent 24f89173d1
commit 2e8cf0b450
5 changed files with 260 additions and 306 deletions


@@ -231,13 +231,11 @@ array layer_norm(
 const std::vector<array>& inputs) {
   auto x = astype(inputs[0], float32, s);
-  // Should I not be smart here and leave the double mean to simplify()?
   auto mu = mean(x, /* axis= */ -1, /* keepdims= */ true, s);
-  auto mu2 = square(mu, s);
-  auto x2 = mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
-  auto v = subtract(x2, mu2, s);
+  auto xc = subtract(x, mu, s);
+  auto v = mean(square(xc, s), /* axis= */ -1, /* keepdims= */ true, s);
-  x = multiply(subtract(x, mu, s), rsqrt(add(v, array(eps, float32), s), s));
+  x = multiply(xc, rsqrt(add(v, array(eps, float32), s), s));
   x = astype(x, out_type, s);
   // If the LN is affine then transform x according to the weight and bias
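
For context, the usual motivation for switching to a two-pass variance: the old formulation computes E[x^2] - (E[x])^2, which subtracts two large, nearly equal quantities when the mean is large relative to the spread, so in float32 the cancellation can wipe out most significant digits (and can even produce small negative variances). Computing E[(x - mu)^2] in a second pass keeps the subtracted quantities small. The following is a minimal standalone C++ sketch of that effect; it is not MLX code, and the data values and function names are illustrative only.

#include <cstdio>
#include <vector>

// One-pass variance: E[x^2] - (E[x])^2. Both terms are ~1e8 for the data
// below, so their float32 difference suffers catastrophic cancellation.
float variance_one_pass(const std::vector<float>& x) {
  float sum = 0.0f, sum_sq = 0.0f;
  for (float v : x) {
    sum += v;
    sum_sq += v * v;
  }
  float n = static_cast<float>(x.size());
  float mu = sum / n;
  return sum_sq / n - mu * mu;
}

// Two-pass variance: first compute the mean, then average the squared
// deviations. The intermediate quantities stay small, so precision is kept.
float variance_two_pass(const std::vector<float>& x) {
  float sum = 0.0f;
  for (float v : x) {
    sum += v;
  }
  float n = static_cast<float>(x.size());
  float mu = sum / n;
  float acc = 0.0f;
  for (float v : x) {
    float d = v - mu;
    acc += d * d;
  }
  return acc / n;
}

int main() {
  // Large mean (1e4), small spread; the true population variance is 0.5.
  std::vector<float> x = {9999.0f, 10000.0f, 10001.0f, 10000.0f};
  std::printf("one-pass: %g\n", variance_one_pass(x)); // cancels to 0 in float32
  std::printf("two-pass: %g\n", variance_two_pass(x)); // prints 0.5
  return 0;
}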