mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
add zero vjps for bitwise ops and gather w.r.t. index (#1256)
This commit is contained in:
@@ -1043,6 +1043,19 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
a_mlx = mx.array(a_np)
|
||||
self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0])))
|
||||
|
||||
def test_indexing_grad(self):
|
||||
x = mx.array([[1, 2], [3, 4]]).astype(mx.float32)
|
||||
ind = mx.array([0, 1, 0]).astype(mx.float32)
|
||||
|
||||
def index_fn(x, ind):
|
||||
return x[ind.astype(mx.int32)].sum()
|
||||
|
||||
grad_x, grad_ind = mx.grad(index_fn, argnums=(0, 1))(x, ind)
|
||||
expected = mx.array([[2, 2], [1, 1]])
|
||||
|
||||
self.assertTrue(mx.array_equal(grad_x, expected))
|
||||
self.assertTrue(mx.array_equal(grad_ind, mx.zeros(ind.shape)))
|
||||
|
||||
def test_setitem(self):
|
||||
a = mx.array(0)
|
||||
a[None] = 1
|
||||
|
||||
Reference in New Issue
Block a user