mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	add zero vjps for bitwise ops and gather w.r.t. index (#1256)
This commit is contained in:
		@@ -2373,6 +2373,21 @@ class TestOps(mlx_tests.MLXTestCase):
 | 
			
		||||
        self.assertEqual(c.shape, (3, 2, 5))
 | 
			
		||||
        self.assertTrue(mx.array_equal(c, mx.ones((3, 2, 5), dtype=mx.bool_)))
 | 
			
		||||
 | 
			
		||||
    def test_bitwise_grad(self):
 | 
			
		||||
        a = np.random.randint(0, 10, size=(4, 3))
 | 
			
		||||
        b = np.random.randint(0, 10, size=(4, 3))
 | 
			
		||||
        cotangent = np.random.randint(0, 10, size=(4, 3))
 | 
			
		||||
        a = mx.array(a)
 | 
			
		||||
        b = mx.array(b)
 | 
			
		||||
        cotangent = mx.array(cotangent)
 | 
			
		||||
 | 
			
		||||
        def bitwise(a, b):
 | 
			
		||||
            return a.astype(mx.int32) & b.astype(mx.int32)
 | 
			
		||||
 | 
			
		||||
        _, vjps = mx.vjp(bitwise, [a, b], [cotangent])
 | 
			
		||||
        for vjp in vjps:
 | 
			
		||||
            self.assertFalse(np.any(np.array(vjp)))
 | 
			
		||||
 | 
			
		||||
    def test_conjugate(self):
 | 
			
		||||
        shape = (3, 5, 7)
 | 
			
		||||
        a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user