Fix shape check

This commit is contained in:
Angelos Katharopoulos 2025-05-13 13:21:34 -07:00
parent 134ed4a58a
commit 40d2fc1263

View File

@ -1259,7 +1259,7 @@ class TestOps(mlx_tests.MLXTestCase):
b = mx.put_along_axis(a, a, a, axis=None)
mx.eval(b)
self.assertEqual(b.size, 0)
self.assertEqual(b.shape, tuple())
self.assertEqual(b.shape, a.shape)
def test_split(self):
a = mx.array([1, 2, 3])