mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 20:41:13 +08:00
Change the variance computation (#319)
This commit is contained in:
parent
0de5988f92
commit
199aebcf77
15
mlx/ops.cpp
15
mlx/ops.cpp
@ -1337,11 +1337,18 @@ array var(
|
|||||||
bool keepdims /* = false */,
|
bool keepdims /* = false */,
|
||||||
int ddof /* = 0*/,
|
int ddof /* = 0*/,
|
||||||
StreamOrDevice s /* = {}*/) {
|
StreamOrDevice s /* = {}*/) {
|
||||||
auto nelements = compute_number_of_elements(a, axes);
|
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto mu = mean(a, axes, true, s);
|
auto mu2 = square(mean(a, axes, keepdims, s), s);
|
||||||
auto S = sum(square(subtract(a, mu, s), s), axes, keepdims, s);
|
auto a2 = mean(square(a, s), axes, keepdims, s);
|
||||||
return multiply(S, array(1.0 / (nelements - ddof), dtype), s);
|
auto v = subtract(a2, mu2, s);
|
||||||
|
|
||||||
|
if (ddof != 0) {
|
||||||
|
auto nelements = compute_number_of_elements(a, axes);
|
||||||
|
float factor = nelements / (nelements - ddof);
|
||||||
|
v = multiply(v, array(factor, dtype), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
array var(
|
array var(
|
||||||
|
Loading…
Reference in New Issue
Block a user