mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 09:58:17 +08:00
add zero vjps for bitwise ops and gather w.r.t. index (#1256)
This commit is contained in:
@@ -2373,6 +2373,21 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(c.shape, (3, 2, 5))
|
||||
self.assertTrue(mx.array_equal(c, mx.ones((3, 2, 5), dtype=mx.bool_)))
|
||||
|
||||
def test_bitwise_grad(self):
|
||||
a = np.random.randint(0, 10, size=(4, 3))
|
||||
b = np.random.randint(0, 10, size=(4, 3))
|
||||
cotangent = np.random.randint(0, 10, size=(4, 3))
|
||||
a = mx.array(a)
|
||||
b = mx.array(b)
|
||||
cotangent = mx.array(cotangent)
|
||||
|
||||
def bitwise(a, b):
|
||||
return a.astype(mx.int32) & b.astype(mx.int32)
|
||||
|
||||
_, vjps = mx.vjp(bitwise, [a, b], [cotangent])
|
||||
for vjp in vjps:
|
||||
self.assertFalse(np.any(np.array(vjp)))
|
||||
|
||||
def test_conjugate(self):
|
||||
shape = (3, 5, 7)
|
||||
a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
|
||||
|
||||
Reference in New Issue
Block a user