mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-03 22:34:43 +08:00
Add vmap to scatter (#1200)
* Add vmap to scatter * updates * vmap updates + a few more tests * bug fix --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -370,6 +370,98 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
|
||||
)
|
||||
|
||||
def test_vmap_scatter(self):
|
||||
def scatter(a):
|
||||
a[mx.array(0)] = mx.array(0.0)
|
||||
return a
|
||||
|
||||
a = mx.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
|
||||
out = mx.vmap(scatter)(a)
|
||||
expected = mx.array([[0.0, 2.0, 3.0], [0.0, 3.0, 4.0]])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
out = mx.vmap(scatter, in_axes=(1,), out_axes=1)(a)
|
||||
expected = mx.array([[0.0, 0.0, 0.0], [2.0, 3.0, 4.0]])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def scatter_add(a):
|
||||
return a.at[mx.array(0)].add(mx.array(1.0))
|
||||
|
||||
a = mx.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
|
||||
out = mx.vmap(scatter_add)(a)
|
||||
expected = mx.array([[2.0, 2.0, 3.0], [3.0, 3.0, 4.0]])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
out = mx.vmap(scatter_add, in_axes=(1,), out_axes=1)(a)
|
||||
expected = mx.array([[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# Multiple indices
|
||||
def scatter(a):
|
||||
a[mx.array([0, 1]), mx.array([0, 1])] = mx.array((1.0, 1.0))
|
||||
return a
|
||||
|
||||
a = mx.zeros((3, 3, 3))
|
||||
|
||||
expected = mx.repeat(scatter(mx.zeros((3, 3)))[None], 3, axis=0)
|
||||
out = mx.vmap(scatter, in_axes=(0,), out_axes=0)(a)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
expected = mx.zeros((3, 3, 3))
|
||||
expected[0, :, 0] = 1
|
||||
expected[1, :, 1] = 1
|
||||
out = mx.vmap(scatter, in_axes=(1,), out_axes=1)(a)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
expected = mx.zeros((3, 3, 3))
|
||||
expected[0, 0, :] = 1
|
||||
expected[1, 1, :] = 1
|
||||
out = mx.vmap(scatter, in_axes=(2,), out_axes=2)(a)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# vmap over src and indices
|
||||
def scatter(a, idx):
|
||||
a[idx] = mx.array(1.0)
|
||||
return a
|
||||
|
||||
a = mx.zeros((3, 4))
|
||||
idx = mx.array([0, 1, 2])
|
||||
out = mx.vmap(scatter, in_axes=(0, 0), out_axes=0)(a, idx)
|
||||
self.assertTrue(mx.allclose(out, mx.eye(n=3, m=4)))
|
||||
|
||||
# vmap over only indices
|
||||
out = mx.vmap(scatter, in_axes=(None, 0), out_axes=0)(a, idx)
|
||||
expected = mx.zeros((3, 3, 4))
|
||||
expected[0, 0] = 1
|
||||
expected[1, 1] = 1
|
||||
expected[2, 2] = 1
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# vmap over src, indices, updates
|
||||
def scatter(a, idx, updates):
|
||||
a[idx] = updates
|
||||
return a
|
||||
|
||||
a = mx.zeros((3, 4))
|
||||
idx = mx.array([0, 1, 2])
|
||||
updates = mx.array([1, 2, 3])
|
||||
out = mx.vmap(scatter, in_axes=(0, 0, 0), out_axes=0)(a, idx, updates)
|
||||
expected = mx.diag(mx.array([1, 2, 3]), k=-1)[1:]
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
# vmap over only updates
|
||||
def scatter(a, idx, updates):
|
||||
a[idx] = updates
|
||||
return a
|
||||
|
||||
a = mx.zeros((3, 4))
|
||||
idx = mx.array([0])
|
||||
updates = mx.array([1, 2, 3])
|
||||
out = mx.vmap(scatter, in_axes=(None, None, 0), out_axes=0)(a, idx, updates)
|
||||
expected = mx.zeros((3, 3, 4))
|
||||
expected[:, 0] = mx.array([1, 2, 3])[:, None]
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user