mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-10 22:46:53 +08:00
scatter axis + gather axis primitives (#1813)
* scatter axis + gather axis primitives * add transforms * comment
This commit is contained in:
@@ -669,6 +669,37 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
_, (expected,) = mx.jvp(lambda c: mx.addmm(c, a, b), (c,), (z,))
|
||||
self.assertTrue(mx.allclose(tangent, expected))
|
||||
|
||||
def test_put_along_axis_grads(self):
|
||||
a = mx.zeros((5, 1))
|
||||
b = mx.ones((2, 1))
|
||||
|
||||
def fun(a, b):
|
||||
idx = mx.array([[0], [3]])
|
||||
return mx.put_along_axis(a, idx, b, axis=0)
|
||||
|
||||
# Test VJP
|
||||
cotan = mx.full((5, 1), 2.0)
|
||||
_, (da, db) = mx.vjp(fun, (a, b), (cotan,))
|
||||
expected_da = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]
|
||||
expected_db = mx.array([2.0, 2.0])[:, None]
|
||||
self.assertTrue(mx.allclose(expected_da, da))
|
||||
self.assertTrue(mx.allclose(expected_db, db))
|
||||
|
||||
# Test JVP
|
||||
tan_a = mx.full((5, 1), 2.0)
|
||||
tan_b = mx.full((2, 1), 3.0)
|
||||
_, (jout,) = mx.jvp(fun, (a, b), (tan_a, tan_b))
|
||||
expected = mx.array([3.0, 2.0, 2.0, 3.0, 2.0])[:, None]
|
||||
self.assertTrue(mx.allclose(expected, jout))
|
||||
|
||||
def fun(a):
|
||||
idx = mx.array([[0], [3]])
|
||||
return mx.put_along_axis(a, idx, b, axis=0)
|
||||
|
||||
_, (jout,) = mx.jvp(fun, (a,), (tan_a,))
|
||||
expected = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]
|
||||
self.assertTrue(mx.allclose(expected, jout))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1150,6 +1150,15 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out_mlx = mx.put_along_axis(a_mlx, idx_mlx, values_mlx, axis=ax)
|
||||
self.assertTrue(np.array_equal(a_np, out_mlx))
|
||||
|
||||
source = mx.zeros((1, 1, 8, 32))
|
||||
indices = mx.array([0, 2, 4, 5]).reshape((1, 1, 4, 1))
|
||||
update = mx.array(1.0)
|
||||
|
||||
out_mlx = mx.put_along_axis(source, indices, update, axis=-2)
|
||||
out_np = np.array(source)
|
||||
np.put_along_axis(out_np, np.array(indices), np.array(update), axis=-2)
|
||||
self.assertTrue(np.array_equal(out_np, np.array(out_mlx)))
|
||||
|
||||
def test_split(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
splits = mx.split(a, 3)
|
||||
|
||||
@@ -549,6 +549,53 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
target = mx.concatenate([x, mx.ones((2, 2, 1))], axis=2)
|
||||
self.assertTrue(mx.array_equal(out, target))
|
||||
|
||||
def test_vmap_take_along_axis(self):
|
||||
a = mx.zeros((4, 5, 1))
|
||||
idx = mx.zeros((2, 4, 1), mx.int32)
|
||||
|
||||
def fun(a, idx):
|
||||
return mx.take_along_axis(a, idx, axis=0)
|
||||
|
||||
out = mx.vmap(fun, in_axes=(0, 1))(a, idx)
|
||||
self.assertEqual(out.shape, (4, 2, 1))
|
||||
|
||||
idx = mx.zeros((2, 1), mx.int32)
|
||||
|
||||
out = mx.vmap(fun, in_axes=(0, None))(a, idx)
|
||||
self.assertEqual(out.shape, (4, 2, 1))
|
||||
|
||||
a = mx.zeros((5, 1))
|
||||
idx = mx.zeros((4, 2, 1), mx.int32)
|
||||
|
||||
out = mx.vmap(fun, in_axes=(None, 0))(a, idx)
|
||||
self.assertEqual(out.shape, (4, 2, 1))
|
||||
|
||||
def test_vmap_put_along_axis(self):
|
||||
a = mx.zeros((4, 5, 1))
|
||||
idx = mx.ones((2, 4, 1), mx.int32)
|
||||
upd = mx.ones((2, 4, 1))
|
||||
|
||||
def fun(a, idx, upd):
|
||||
return mx.put_along_axis(a, idx, upd, axis=0)
|
||||
|
||||
out = mx.vmap(fun, in_axes=(0, 1, 1))(a, idx, upd)
|
||||
self.assertEqual(out.shape, (4, 5, 1))
|
||||
|
||||
upd = mx.ones((2, 1))
|
||||
out = mx.vmap(fun, in_axes=(0, 1, None))(a, idx, upd)
|
||||
self.assertEqual(out.shape, (4, 5, 1))
|
||||
|
||||
idx = mx.ones((2, 1), mx.int32)
|
||||
upd = mx.ones((2, 1))
|
||||
out = mx.vmap(fun, in_axes=(0, None, None))(a, idx, upd)
|
||||
self.assertEqual(out.shape, (4, 5, 1))
|
||||
|
||||
a = mx.zeros((5, 1))
|
||||
idx = mx.ones((2, 4, 1), mx.int32)
|
||||
upd = mx.ones((2, 4, 1))
|
||||
out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd)
|
||||
self.assertEqual(out.shape, (4, 5, 1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user