From 969337345ffd34159884c8bf124dbb6eeaa902bb Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 1 Sep 2024 21:37:51 -0700 Subject: [PATCH] Fix reduce edge case (#1389) --- mlx/backend/common/reduce_utils.cpp | 2 +- python/tests/test_reduce.py | 30 ++++++++++++++--------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/mlx/backend/common/reduce_utils.cpp b/mlx/backend/common/reduce_utils.cpp index b15bc9e18..64049873e 100644 --- a/mlx/backend/common/reduce_utils.cpp +++ b/mlx/backend/common/reduce_utils.cpp @@ -32,7 +32,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector& axes) { std::vector shape = {x.shape(axes[0])}; std::vector strides = {x.strides()[axes[0]]}; for (int i = 1; i < axes.size(); i++) { - if (axes[i] - 1 == axes[i - 1]) { + if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) { shape.back() *= x.shape(axes[i]); strides.back() = x.strides()[axes[i]]; } else { diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index c59e7f490..368b2128a 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -10,21 +10,21 @@ import numpy as np class TestReduce(mlx_tests.MLXTestCase): def test_axis_permutation_sums(self): - x_npy = np.random.randn(5, 5, 5, 5, 5).astype(np.float32) - x_mlx = mx.array(x_npy) - for t in permutations(range(5)): - with self.subTest(t=t): - y_npy = np.transpose(x_npy, t) - y_mlx = mx.transpose(x_mlx, t) - for n in range(1, 6): - for a in combinations(range(5), n): - with self.subTest(a=a): - z_npy = np.sum(y_npy, axis=a) - z_mlx = mx.sum(y_mlx, axis=a) - mx.eval(z_mlx) - self.assertTrue( - np.allclose(z_npy, np.array(z_mlx), atol=1e-4) - ) + for shape in [(5, 5, 1, 5, 5), (65, 65, 1, 65)]: + with self.subTest(shape=shape): + x_npy = (np.random.randn(*shape) * 128).astype(np.int32) + x_mlx = mx.array(x_npy) + for t in permutations(range(len(shape))): + with self.subTest(t=t): + y_npy = np.transpose(x_npy, t) + y_mlx = mx.transpose(x_mlx, t) + for n in range(1, len(shape) + 1): + for a in combinations(range(len(shape)), n): + with self.subTest(a=a): + z_npy = np.sum(y_npy, axis=a) + z_mlx = mx.sum(y_mlx, axis=a) + mx.eval(z_mlx) + self.assertTrue(np.all(z_npy == z_mlx)) def test_expand_sums(self): x_npy = np.random.randn(5, 1, 5, 1, 5, 1).astype(np.float32)