From 199aebcf77d5342496947df2a7b7570245cc6218 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 30 Jan 2024 19:28:56 -0800 Subject: [PATCH] Change the variance computation (#319) --- mlx/ops.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 31324f3cc..b97bf5621 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1337,11 +1337,18 @@ array var( bool keepdims /* = false */, int ddof /* = 0*/, StreamOrDevice s /* = {}*/) { - auto nelements = compute_number_of_elements(a, axes); auto dtype = at_least_float(a.dtype()); - auto mu = mean(a, axes, true, s); - auto S = sum(square(subtract(a, mu, s), s), axes, keepdims, s); - return multiply(S, array(1.0 / (nelements - ddof), dtype), s); + auto mu2 = square(mean(a, axes, keepdims, s), s); + auto a2 = mean(square(a, s), axes, keepdims, s); + auto v = subtract(a2, mu2, s); + + if (ddof != 0) { + auto nelements = compute_number_of_elements(a, axes); + float factor = nelements / (nelements - ddof); + v = multiply(v, array(factor, dtype), s); + } + + return v; } array var(