diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index e078a4b40d..f3d48dda36 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1259,7 +1259,7 @@ class TestOps(mlx_tests.MLXTestCase): b = mx.put_along_axis(a, a, a, axis=None) mx.eval(b) self.assertEqual(b.size, 0) - self.assertEqual(b.shape, tuple()) + self.assertEqual(b.shape, a.shape) def test_split(self): a = mx.array([1, 2, 3])