diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 53f612642..aecf3a65c 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1640,17 +1640,21 @@ array var( int ddof /* = 0*/, StreamOrDevice s /* = {}*/) { auto dtype = at_least_float(a.dtype()); - auto mu2 = square(mean(a, axes, keepdims, s), s); - auto a2 = mean(square(a, s), axes, keepdims, s); - auto v = subtract(a2, mu2, s); + auto mu = mean(a, axes, /* keepdims= */ true, s); + auto v = sum(square(subtract(a, mu, s), s), axes, keepdims, s); if (ddof != 0) { - auto nelements = number_of_elements(a, axes, false, dtype, s); - auto factor = divide( - nelements, - maximum(subtract(nelements, array(ddof, dtype), s), array(0, dtype), s), + auto normalizer = maximum( + subtract( + number_of_elements(a, axes, false, dtype, s), + array(ddof, dtype), + s), + array(0, dtype), s); - v = multiply(v, factor, s); + v = divide(v, normalizer, s); + } else { + auto normalizer = number_of_elements(a, axes, true, dtype, s); + v = multiply(v, normalizer, s); } return v;