mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Revert variance to be numerically stable (#1314)
This commit is contained in:
parent
30bbea2f08
commit
eb8819e91e
20
mlx/ops.cpp
20
mlx/ops.cpp
@ -1640,17 +1640,21 @@ array var(
|
||||
int ddof /* = 0*/,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto mu2 = square(mean(a, axes, keepdims, s), s);
|
||||
auto a2 = mean(square(a, s), axes, keepdims, s);
|
||||
auto v = subtract(a2, mu2, s);
|
||||
auto mu = mean(a, axes, /* keepdims= */ true, s);
|
||||
auto v = sum(square(subtract(a, mu, s), s), axes, keepdims, s);
|
||||
|
||||
if (ddof != 0) {
|
||||
auto nelements = number_of_elements(a, axes, false, dtype, s);
|
||||
auto factor = divide(
|
||||
nelements,
|
||||
maximum(subtract(nelements, array(ddof, dtype), s), array(0, dtype), s),
|
||||
auto normalizer = maximum(
|
||||
subtract(
|
||||
number_of_elements(a, axes, false, dtype, s),
|
||||
array(ddof, dtype),
|
||||
s),
|
||||
array(0, dtype),
|
||||
s);
|
||||
v = multiply(v, factor, s);
|
||||
v = divide(v, normalizer, s);
|
||||
} else {
|
||||
auto normalizer = number_of_elements(a, axes, true, dtype, s);
|
||||
v = multiply(v, normalizer, s);
|
||||
}
|
||||
|
||||
return v;
|
||||
|
Loading…
Reference in New Issue
Block a user