Revert variance to be numerically stable (#1314)

This commit is contained in:
Angelos Katharopoulos 2024-08-08 13:35:02 -07:00 committed by GitHub
parent 30bbea2f08
commit eb8819e91e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1640,17 +1640,21 @@ array var(
int ddof /* = 0*/, int ddof /* = 0*/,
StreamOrDevice s /* = {}*/) { StreamOrDevice s /* = {}*/) {
auto dtype = at_least_float(a.dtype()); auto dtype = at_least_float(a.dtype());
auto mu2 = square(mean(a, axes, keepdims, s), s); auto mu = mean(a, axes, /* keepdims= */ true, s);
auto a2 = mean(square(a, s), axes, keepdims, s); auto v = sum(square(subtract(a, mu, s), s), axes, keepdims, s);
auto v = subtract(a2, mu2, s);
if (ddof != 0) { if (ddof != 0) {
auto nelements = number_of_elements(a, axes, false, dtype, s); auto normalizer = maximum(
auto factor = divide( subtract(
nelements, number_of_elements(a, axes, false, dtype, s),
maximum(subtract(nelements, array(ddof, dtype), s), array(0, dtype), s), array(ddof, dtype),
s),
array(0, dtype),
s); s);
v = multiply(v, factor, s); v = divide(v, normalizer, s);
} else {
auto normalizer = number_of_elements(a, axes, true, dtype, s);
v = multiply(v, normalizer, s);
} }
return v; return v;