add zero vjps for bitwise ops and gather w.r.t. index (#1256)

This commit is contained in:
Alex Barron
2024-07-07 21:34:59 -07:00
committed by GitHub
parent 20bb301195
commit bdb36c9a63
4 changed files with 62 additions and 7 deletions

View File

@@ -1043,6 +1043,19 @@ class TestArray(mlx_tests.MLXTestCase):
a_mlx = mx.array(a_np)
self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0])))
def test_indexing_grad(self):
    """Check gradients of array indexing w.r.t. the array and the indices.

    The indexed array should receive one unit of gradient per time a row
    is selected, while the (float-typed) index argument should receive an
    all-zero gradient of its own shape.
    """
    source = mx.array([[1, 2], [3, 4]]).astype(mx.float32)
    # The indices are passed in as float32 and only cast to int32 inside
    # the differentiated function, so a gradient can be requested for them.
    rows = mx.array([0, 1, 0]).astype(mx.float32)

    def gather_and_sum(arr, idx):
        picked = arr[idx.astype(mx.int32)]
        return picked.sum()

    grad_fn = mx.grad(gather_and_sum, argnums=(0, 1))
    d_source, d_rows = grad_fn(source, rows)

    # Row 0 is selected twice and row 1 once.
    want_source = mx.array([[2, 2], [1, 1]])
    self.assertTrue(mx.array_equal(d_source, want_source))
    self.assertTrue(mx.array_equal(d_rows, mx.zeros(rows.shape)))
def test_setitem(self):
a = mx.array(0)
a[None] = 1

View File

@@ -2373,6 +2373,21 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(c.shape, (3, 2, 5))
self.assertTrue(mx.array_equal(c, mx.ones((3, 2, 5), dtype=mx.bool_)))
def test_bitwise_grad(self):
    """Check that the VJPs of a bitwise AND are zero for both operands.

    Two random integer operands and a random cotangent, each of shape
    (4, 3) with values in [0, 10), are pushed through ``mx.vjp``; every
    returned VJP must contain only zeros.
    """
    shape = (4, 3)
    # Drawn in the order: left operand, right operand, cotangent.
    lhs, rhs, cotan = (
        mx.array(np.random.randint(0, 10, size=shape)) for _ in range(3)
    )

    def bitwise_and(p, q):
        return p.astype(mx.int32) & q.astype(mx.int32)

    _, input_vjps = mx.vjp(bitwise_and, [lhs, rhs], [cotan])
    for g in input_vjps:
        # np.any is False only when every element of the VJP is zero.
        self.assertFalse(np.any(np.array(g)))
def test_conjugate(self):
shape = (3, 5, 7)
a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)