From 3aa9cf3f9ed7e1dd508b0d98b07834f5ac5c43cf Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 13 May 2025 14:27:53 -0700 Subject: [PATCH] Fix put_along_axis for empty arrays (#2181) --- mlx/ops.cpp | 4 ++++ python/tests/test_ops.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 922680110..0c18cccfe 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3175,6 +3175,10 @@ array scatter_axis( throw std::invalid_argument(msg.str()); } + if (a.size() == 0) { + return a; + } + auto upd = astype(values, a.dtype(), s); // Squeeze leading singletons out of update diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 0921de788..f3d48dda3 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1255,6 +1255,12 @@ class TestOps(mlx_tests.MLXTestCase): np.put_along_axis(out_np, np.array(indices), np.array(update), axis=-2) self.assertTrue(np.array_equal(out_np, np.array(out_mlx))) + a = mx.array([], mx.float32) + b = mx.put_along_axis(a, a, a, axis=None) + mx.eval(b) + self.assertEqual(b.size, 0) + self.assertEqual(b.shape, a.shape) + def test_split(self): a = mx.array([1, 2, 3]) splits = mx.split(a, 3)