mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 12:31:13 +08:00
Change the variance computation (#319)
This commit is contained in:
parent
0de5988f92
commit
199aebcf77
15
mlx/ops.cpp
15
mlx/ops.cpp
@ -1337,11 +1337,18 @@ array var(
|
||||
bool keepdims /* = false */,
|
||||
int ddof /* = 0*/,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
auto nelements = compute_number_of_elements(a, axes);
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto mu = mean(a, axes, true, s);
|
||||
auto S = sum(square(subtract(a, mu, s), s), axes, keepdims, s);
|
||||
return multiply(S, array(1.0 / (nelements - ddof), dtype), s);
|
||||
auto mu2 = square(mean(a, axes, keepdims, s), s);
|
||||
auto a2 = mean(square(a, s), axes, keepdims, s);
|
||||
auto v = subtract(a2, mu2, s);
|
||||
|
||||
if (ddof != 0) {
|
||||
auto nelements = compute_number_of_elements(a, axes);
|
||||
float factor = nelements / (nelements - ddof);
|
||||
v = multiply(v, array(factor, dtype), s);
|
||||
}
|
||||
|
||||
return v;
|
||||
}
|
||||
|
||||
array var(
|
||||
|
Loading…
Reference in New Issue
Block a user