mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-03 22:34:43 +08:00
@@ -370,6 +370,51 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
|
||||
)
|
||||
|
||||
def test_vmap_gather(self):
|
||||
def gather(a, idx):
|
||||
return a[idx]
|
||||
|
||||
a = mx.array([[1, 2], [3, 4]])
|
||||
idx = mx.array(0)
|
||||
out = mx.vmap(gather, (0, None))(a, idx)
|
||||
self.assertTrue(mx.array_equal(out, mx.array([1, 3])))
|
||||
|
||||
out = mx.vmap(gather, (1, None))(a, idx)
|
||||
self.assertTrue(mx.array_equal(out, mx.array([1, 2])))
|
||||
|
||||
idx = mx.array([0, 1])
|
||||
out = mx.vmap(gather, (0, 0))(a, idx)
|
||||
self.assertTrue(mx.array_equal(out, mx.array([1, 4])))
|
||||
|
||||
a = mx.ones((2, 3, 4))
|
||||
idx = mx.zeros(4, mx.int32)
|
||||
out = mx.vmap(gather, (2, 0))(a, idx)
|
||||
self.assertEqual(out.shape, (4, 3))
|
||||
|
||||
f = mx.vmap(gather, (0, None))
|
||||
f = mx.vmap(gather, (0, 0))
|
||||
out = f(mx.ones((2, 3, 4)), mx.zeros(2, dtype=mx.int32))
|
||||
self.assertEqual(out.shape, (2, 4))
|
||||
|
||||
def gather(a, idxa, idxb):
|
||||
return a[idxa, idxb]
|
||||
|
||||
a = mx.ones((2, 3, 4))
|
||||
idxa = mx.zeros((2, 3), mx.int32)
|
||||
idxb = mx.zeros(3, mx.int32)
|
||||
out = mx.vmap(gather, (0, 0, None))(a, idxa, idxb)
|
||||
self.assertEqual(out.shape, (2, 3))
|
||||
|
||||
idxa = mx.zeros((3, 1, 2), mx.int32)
|
||||
idxb = mx.zeros((2, 3, 1, 2), mx.int32)
|
||||
out = mx.vmap(gather, (0, None, 0))(a, idxa, idxb)
|
||||
self.assertEqual(out.shape, (2, 3, 1, 2))
|
||||
|
||||
idxa = mx.zeros((3, 1, 2), mx.int32)
|
||||
idxb = mx.zeros((3, 1, 2, 2), mx.int32)
|
||||
out = mx.vmap(gather, (0, None, 3))(a, idxa, idxb)
|
||||
self.assertEqual(out.shape, (2, 3, 1, 2))
|
||||
|
||||
def test_vmap_scatter(self):
|
||||
def scatter(a):
|
||||
a[mx.array(0)] = mx.array(0.0)
|
||||
|
Reference in New Issue
Block a user