diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 922680110..0c18cccfe 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3175,6 +3175,10 @@ array scatter_axis( throw std::invalid_argument(msg.str()); } + if (a.size() == 0) { + return a; + } + auto upd = astype(values, a.dtype(), s); // Squeeze leading singletons out of update diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 0921de788..f3d48dda3 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1255,6 +1255,12 @@ class TestOps(mlx_tests.MLXTestCase): np.put_along_axis(out_np, np.array(indices), np.array(update), axis=-2) self.assertTrue(np.array_equal(out_np, np.array(out_mlx))) + a = mx.array([], mx.float32) + b = mx.put_along_axis(a, a, a, axis=None) + mx.eval(b) + self.assertEqual(b.size, 0) + self.assertEqual(b.shape, a.shape) + def test_split(self): a = mx.array([1, 2, 3]) splits = mx.split(a, 3)