Scatter vjp (#394)

* Add a first scatter vjp
* Implement the scatter_add vjp
* Add array.at to implement user friendly scatters
This commit is contained in:
Angelos Katharopoulos
2024-01-09 13:36:51 -08:00
committed by GitHub
parent e9ca65c939
commit 961435a243
7 changed files with 360 additions and 33 deletions

View File

@@ -984,6 +984,53 @@ class TestArray(mlx_tests.MLXTestCase):
a[2:-2, 2:-2] = 4
self.assertEqual(a[2, 2].item(), 4)
def test_array_at(self):
a = mx.array(1)
a = a.at[None].add(1)
self.assertEqual(a.item(), 2)
a = mx.array([0, 1, 2])
a = a.at[1].add(2)
self.assertEqual(a.tolist(), [0, 3, 2])
a = a.at[mx.array([0, 0, 0, 0])].add(1)
self.assertEqual(a.tolist(), [4, 3, 2])
a = mx.zeros((10, 10))
a = a.at[0].add(mx.arange(10))
self.assertEqual(a[0].tolist(), list(range(10)))
a = mx.zeros((10, 10))
index_x = mx.array([0, 2, 3, 7])
index_y = mx.array([3, 3, 1, 2])
u = mx.random.uniform(shape=(4,))
a = a.at[index_x, index_y].add(u)
self.assertEqual(a.sum().item(), u.sum().item())
self.assertEqual(a[index_x, index_y].tolist(), u.tolist())
# Test all array.at ops
a = mx.random.uniform(shape=(10, 5, 2))
idx_x = mx.array([0, 4])
update = mx.ones((2, 5))
a[idx_x, :, 0] = 0
a = a.at[idx_x, :, 0].add(update)
self.assertEqualArray(a[idx_x, :, 0], update)
a = a.at[idx_x, :, 0].subtract(update)
self.assertEqualArray(a[idx_x, :, 0], mx.zeros_like(update))
a = a.at[idx_x, :, 0].add(2 * update)
self.assertEqualArray(a[idx_x, :, 0], 2 * update)
a = a.at[idx_x, :, 0].multiply(2 * update)
self.assertEqualArray(a[idx_x, :, 0], 4 * update)
a = a.at[idx_x, :, 0].divide(3 * update)
self.assertEqualArray(a[idx_x, :, 0], (4 / 3) * update)
a[idx_x, :, 0] = 5
update = mx.arange(10).reshape(2, 5)
a = a.at[idx_x, :, 0].maximum(update)
self.assertEqualArray(a[idx_x, :, 0], mx.maximum(a[idx_x, :, 0], update))
a[idx_x, :, 0] = 5
a = a.at[idx_x, :, 0].minimum(update)
self.assertEqualArray(a[idx_x, :, 0], mx.minimum(a[idx_x, :, 0], update))
def test_slice_negative_step(self):
a_np = np.arange(20)
a_mx = mx.array(a_np)