mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
scatter axis + gather axis primitives (#1813)
* scatter axis + gather axis primitives * add transforms * comment
This commit is contained in:
@@ -549,6 +549,53 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
target = mx.concatenate([x, mx.ones((2, 2, 1))], axis=2)
|
||||
self.assertTrue(mx.array_equal(out, target))
|
||||
|
||||
def test_vmap_take_along_axis(self):
|
||||
a = mx.zeros((4, 5, 1))
|
||||
idx = mx.zeros((2, 4, 1), mx.int32)
|
||||
|
||||
def fun(a, idx):
|
||||
return mx.take_along_axis(a, idx, axis=0)
|
||||
|
||||
out = mx.vmap(fun, in_axes=(0, 1))(a, idx)
|
||||
self.assertEqual(out.shape, (4, 2, 1))
|
||||
|
||||
idx = mx.zeros((2, 1), mx.int32)
|
||||
|
||||
out = mx.vmap(fun, in_axes=(0, None))(a, idx)
|
||||
self.assertEqual(out.shape, (4, 2, 1))
|
||||
|
||||
a = mx.zeros((5, 1))
|
||||
idx = mx.zeros((4, 2, 1), mx.int32)
|
||||
|
||||
out = mx.vmap(fun, in_axes=(None, 0))(a, idx)
|
||||
self.assertEqual(out.shape, (4, 2, 1))
|
||||
|
||||
def test_vmap_put_along_axis(self):
|
||||
a = mx.zeros((4, 5, 1))
|
||||
idx = mx.ones((2, 4, 1), mx.int32)
|
||||
upd = mx.ones((2, 4, 1))
|
||||
|
||||
def fun(a, idx, upd):
|
||||
return mx.put_along_axis(a, idx, upd, axis=0)
|
||||
|
||||
out = mx.vmap(fun, in_axes=(0, 1, 1))(a, idx, upd)
|
||||
self.assertEqual(out.shape, (4, 5, 1))
|
||||
|
||||
upd = mx.ones((2, 1))
|
||||
out = mx.vmap(fun, in_axes=(0, 1, None))(a, idx, upd)
|
||||
self.assertEqual(out.shape, (4, 5, 1))
|
||||
|
||||
idx = mx.ones((2, 1), mx.int32)
|
||||
upd = mx.ones((2, 1))
|
||||
out = mx.vmap(fun, in_axes=(0, None, None))(a, idx, upd)
|
||||
self.assertEqual(out.shape, (4, 5, 1))
|
||||
|
||||
a = mx.zeros((5, 1))
|
||||
idx = mx.ones((2, 4, 1), mx.int32)
|
||||
upd = mx.ones((2, 4, 1))
|
||||
out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd)
|
||||
self.assertEqual(out.shape, (4, 5, 1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user