Add vmap to scatter (#1200)

* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
nicolov
2024-08-06 05:12:27 +02:00
committed by GitHub
parent 58d0e199e1
commit 8c9f0278b9
6 changed files with 269 additions and 8 deletions

View File

@@ -370,6 +370,98 @@ class TestVmap(mlx_tests.MLXTestCase):
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
)
def test_vmap_scatter(self):
def scatter(a):
a[mx.array(0)] = mx.array(0.0)
return a
a = mx.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
out = mx.vmap(scatter)(a)
expected = mx.array([[0.0, 2.0, 3.0], [0.0, 3.0, 4.0]])
self.assertTrue(mx.allclose(out, expected))
out = mx.vmap(scatter, in_axes=(1,), out_axes=1)(a)
expected = mx.array([[0.0, 0.0, 0.0], [2.0, 3.0, 4.0]])
self.assertTrue(mx.allclose(out, expected))
def scatter_add(a):
return a.at[mx.array(0)].add(mx.array(1.0))
a = mx.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
out = mx.vmap(scatter_add)(a)
expected = mx.array([[2.0, 2.0, 3.0], [3.0, 3.0, 4.0]])
self.assertTrue(mx.allclose(out, expected))
out = mx.vmap(scatter_add, in_axes=(1,), out_axes=1)(a)
expected = mx.array([[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]])
self.assertTrue(mx.allclose(out, expected))
# Multiple indices
def scatter(a):
a[mx.array([0, 1]), mx.array([0, 1])] = mx.array((1.0, 1.0))
return a
a = mx.zeros((3, 3, 3))
expected = mx.repeat(scatter(mx.zeros((3, 3)))[None], 3, axis=0)
out = mx.vmap(scatter, in_axes=(0,), out_axes=0)(a)
self.assertTrue(mx.allclose(out, expected))
expected = mx.zeros((3, 3, 3))
expected[0, :, 0] = 1
expected[1, :, 1] = 1
out = mx.vmap(scatter, in_axes=(1,), out_axes=1)(a)
self.assertTrue(mx.allclose(out, expected))
expected = mx.zeros((3, 3, 3))
expected[0, 0, :] = 1
expected[1, 1, :] = 1
out = mx.vmap(scatter, in_axes=(2,), out_axes=2)(a)
self.assertTrue(mx.allclose(out, expected))
# vmap over src and indices
def scatter(a, idx):
a[idx] = mx.array(1.0)
return a
a = mx.zeros((3, 4))
idx = mx.array([0, 1, 2])
out = mx.vmap(scatter, in_axes=(0, 0), out_axes=0)(a, idx)
self.assertTrue(mx.allclose(out, mx.eye(n=3, m=4)))
# vmap over only indices
out = mx.vmap(scatter, in_axes=(None, 0), out_axes=0)(a, idx)
expected = mx.zeros((3, 3, 4))
expected[0, 0] = 1
expected[1, 1] = 1
expected[2, 2] = 1
self.assertTrue(mx.allclose(out, expected))
# vmap over src, indices, updates
def scatter(a, idx, updates):
a[idx] = updates
return a
a = mx.zeros((3, 4))
idx = mx.array([0, 1, 2])
updates = mx.array([1, 2, 3])
out = mx.vmap(scatter, in_axes=(0, 0, 0), out_axes=0)(a, idx, updates)
expected = mx.diag(mx.array([1, 2, 3]), k=-1)[1:]
self.assertTrue(mx.allclose(out, expected))
# vmap over only updates
def scatter(a, idx, updates):
a[idx] = updates
return a
a = mx.zeros((3, 4))
idx = mx.array([0])
updates = mx.array([1, 2, 3])
out = mx.vmap(scatter, in_axes=(None, None, 0), out_axes=0)(a, idx, updates)
expected = mx.zeros((3, 3, 4))
expected[:, 0] = mx.array([1, 2, 3])[:, None]
self.assertTrue(mx.allclose(out, expected))
if __name__ == "__main__":
unittest.main()