From eb8819e91e8b06fcb6e8aeacecea93add589c3da Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 8 Aug 2024 13:35:02 -0700 Subject: [PATCH] Revert variance to be numerically stable (#1314) --- mlx/ops.cpp | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 53f612642..aecf3a65c 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1640,17 +1640,21 @@ array var( int ddof /* = 0*/, StreamOrDevice s /* = {}*/) { auto dtype = at_least_float(a.dtype()); - auto mu2 = square(mean(a, axes, keepdims, s), s); - auto a2 = mean(square(a, s), axes, keepdims, s); - auto v = subtract(a2, mu2, s); + auto mu = mean(a, axes, /* keepdims= */ true, s); + auto v = sum(square(subtract(a, mu, s), s), axes, keepdims, s); if (ddof != 0) { - auto nelements = number_of_elements(a, axes, false, dtype, s); - auto factor = divide( - nelements, - maximum(subtract(nelements, array(ddof, dtype), s), array(0, dtype), s), + auto normalizer = maximum( + subtract( + number_of_elements(a, axes, false, dtype, s), + array(ddof, dtype), + s), + array(0, dtype), s); - v = multiply(v, factor, s); + v = divide(v, normalizer, s); + } else { + auto normalizer = number_of_elements(a, axes, true, dtype, s); + v = multiply(v, normalizer, s); } return v;