Change the variance computation (#319)

This commit is contained in:
Angelos Katharopoulos 2024-01-30 19:28:56 -08:00 committed by GitHub
parent 0de5988f92
commit 199aebcf77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1337,11 +1337,18 @@ array var(
bool keepdims /* = false */,
int ddof /* = 0*/,
StreamOrDevice s /* = {}*/) {
auto nelements = compute_number_of_elements(a, axes);
auto dtype = at_least_float(a.dtype());
auto mu = mean(a, axes, true, s);
auto S = sum(square(subtract(a, mu, s), s), axes, keepdims, s);
return multiply(S, array(1.0 / (nelements - ddof), dtype), s);
auto mu2 = square(mean(a, axes, keepdims, s), s);
auto a2 = mean(square(a, s), axes, keepdims, s);
auto v = subtract(a2, mu2, s);
if (ddof != 0) {
auto nelements = compute_number_of_elements(a, axes);
float factor = nelements / (nelements - ddof);
v = multiply(v, array(factor, dtype), s);
}
return v;
}
array var(