Fix gather vmap (#1563)

* fix gather

* fix
This commit is contained in:
Awni Hannun
2024-11-05 11:29:20 -08:00
committed by GitHub
parent 26be608470
commit 54f05e7195
2 changed files with 83 additions and 29 deletions

View File

@@ -370,6 +370,51 @@ class TestVmap(mlx_tests.MLXTestCase):
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
)
def test_vmap_gather(self):
def gather(a, idx):
return a[idx]
a = mx.array([[1, 2], [3, 4]])
idx = mx.array(0)
out = mx.vmap(gather, (0, None))(a, idx)
self.assertTrue(mx.array_equal(out, mx.array([1, 3])))
out = mx.vmap(gather, (1, None))(a, idx)
self.assertTrue(mx.array_equal(out, mx.array([1, 2])))
idx = mx.array([0, 1])
out = mx.vmap(gather, (0, 0))(a, idx)
self.assertTrue(mx.array_equal(out, mx.array([1, 4])))
a = mx.ones((2, 3, 4))
idx = mx.zeros(4, mx.int32)
out = mx.vmap(gather, (2, 0))(a, idx)
self.assertEqual(out.shape, (4, 3))
f = mx.vmap(gather, (0, None))
f = mx.vmap(gather, (0, 0))
out = f(mx.ones((2, 3, 4)), mx.zeros(2, dtype=mx.int32))
self.assertEqual(out.shape, (2, 4))
def gather(a, idxa, idxb):
return a[idxa, idxb]
a = mx.ones((2, 3, 4))
idxa = mx.zeros((2, 3), mx.int32)
idxb = mx.zeros(3, mx.int32)
out = mx.vmap(gather, (0, 0, None))(a, idxa, idxb)
self.assertEqual(out.shape, (2, 3))
idxa = mx.zeros((3, 1, 2), mx.int32)
idxb = mx.zeros((2, 3, 1, 2), mx.int32)
out = mx.vmap(gather, (0, None, 0))(a, idxa, idxb)
self.assertEqual(out.shape, (2, 3, 1, 2))
idxa = mx.zeros((3, 1, 2), mx.int32)
idxb = mx.zeros((3, 1, 2, 2), mx.int32)
out = mx.vmap(gather, (0, None, 3))(a, idxa, idxb)
self.assertEqual(out.shape, (2, 3, 1, 2))
def test_vmap_scatter(self):
def scatter(a):
a[mx.array(0)] = mx.array(0.0)