mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-08 01:54:37 +08:00
Scatter vjp (#394)
* Add a first scatter vjp * Implement the scatter_add vjp * Add array.at to implement user friendly scatters
This commit is contained in:

committed by
GitHub

parent
e9ca65c939
commit
961435a243
@@ -984,6 +984,53 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
a[2:-2, 2:-2] = 4
|
||||
self.assertEqual(a[2, 2].item(), 4)
|
||||
|
||||
def test_array_at(self):
|
||||
a = mx.array(1)
|
||||
a = a.at[None].add(1)
|
||||
self.assertEqual(a.item(), 2)
|
||||
|
||||
a = mx.array([0, 1, 2])
|
||||
a = a.at[1].add(2)
|
||||
self.assertEqual(a.tolist(), [0, 3, 2])
|
||||
|
||||
a = a.at[mx.array([0, 0, 0, 0])].add(1)
|
||||
self.assertEqual(a.tolist(), [4, 3, 2])
|
||||
|
||||
a = mx.zeros((10, 10))
|
||||
a = a.at[0].add(mx.arange(10))
|
||||
self.assertEqual(a[0].tolist(), list(range(10)))
|
||||
|
||||
a = mx.zeros((10, 10))
|
||||
index_x = mx.array([0, 2, 3, 7])
|
||||
index_y = mx.array([3, 3, 1, 2])
|
||||
u = mx.random.uniform(shape=(4,))
|
||||
a = a.at[index_x, index_y].add(u)
|
||||
self.assertEqual(a.sum().item(), u.sum().item())
|
||||
self.assertEqual(a[index_x, index_y].tolist(), u.tolist())
|
||||
|
||||
# Test all array.at ops
|
||||
a = mx.random.uniform(shape=(10, 5, 2))
|
||||
idx_x = mx.array([0, 4])
|
||||
update = mx.ones((2, 5))
|
||||
a[idx_x, :, 0] = 0
|
||||
a = a.at[idx_x, :, 0].add(update)
|
||||
self.assertEqualArray(a[idx_x, :, 0], update)
|
||||
a = a.at[idx_x, :, 0].subtract(update)
|
||||
self.assertEqualArray(a[idx_x, :, 0], mx.zeros_like(update))
|
||||
a = a.at[idx_x, :, 0].add(2 * update)
|
||||
self.assertEqualArray(a[idx_x, :, 0], 2 * update)
|
||||
a = a.at[idx_x, :, 0].multiply(2 * update)
|
||||
self.assertEqualArray(a[idx_x, :, 0], 4 * update)
|
||||
a = a.at[idx_x, :, 0].divide(3 * update)
|
||||
self.assertEqualArray(a[idx_x, :, 0], (4 / 3) * update)
|
||||
a[idx_x, :, 0] = 5
|
||||
update = mx.arange(10).reshape(2, 5)
|
||||
a = a.at[idx_x, :, 0].maximum(update)
|
||||
self.assertEqualArray(a[idx_x, :, 0], mx.maximum(a[idx_x, :, 0], update))
|
||||
a[idx_x, :, 0] = 5
|
||||
a = a.at[idx_x, :, 0].minimum(update)
|
||||
self.assertEqualArray(a[idx_x, :, 0], mx.minimum(a[idx_x, :, 0], update))
|
||||
|
||||
def test_slice_negative_step(self):
|
||||
a_np = np.arange(20)
|
||||
a_mx = mx.array(a_np)
|
||||
|
Reference in New Issue
Block a user