mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Correct types for vjp + tests (#418)
* correct types for vjp + tests * fix build + comment
This commit is contained in:
		| @@ -274,6 +274,47 @@ class TestAutograd(mlx_tests.MLXTestCase): | ||||
|         mx.eval(state) | ||||
|         self.assertTrue(mx.allclose(state, mx.ones((2,)))) | ||||
|  | ||||
|     def test_scatter_vjp(self): | ||||
|         def fun(x, idx): | ||||
|             x[idx] = 2.0 | ||||
|             return x.sum() | ||||
|  | ||||
|         dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1])) | ||||
|         self.assertTrue(mx.array_equal(dfdx, mx.array([1.0, 0.0, 1.0]))) | ||||
|         self.assertEqual(dfdx.dtype, mx.float32) | ||||
|  | ||||
|         y = mx.array([0.0, 1.0, 2.0]) | ||||
|  | ||||
|         def fun(x, idx): | ||||
|             y[idx] = x | ||||
|             return y.sum() | ||||
|  | ||||
|         dfdx = mx.grad(fun)(mx.array([2.0]), mx.array([1])) | ||||
|         self.assertTrue(mx.array_equal(dfdx, mx.array([1.0]))) | ||||
|         self.assertEqual(dfdx.dtype, mx.float32) | ||||
|  | ||||
|     def test_vjp_types(self): | ||||
|         def fun(x): | ||||
|             return x | ||||
|  | ||||
|         for t in [mx.float16, mx.bfloat16, mx.float32]: | ||||
|             out = mx.grad(fun)(mx.array(1.0, t)) | ||||
|             self.assertEqual(out.dtype, t) | ||||
|  | ||||
|         def fun(x): | ||||
|             return x.sum() | ||||
|  | ||||
|         for t in [mx.float16, mx.bfloat16, mx.float32]: | ||||
|             out = mx.grad(fun)(mx.array(1.0, t)) | ||||
|             self.assertEqual(out.dtype, t) | ||||
|  | ||||
|         def fun(x, y): | ||||
|             return (x + y).sum() | ||||
|  | ||||
|         for t in [mx.float16, mx.bfloat16, mx.float32]: | ||||
|             out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t)) | ||||
|             self.assertEqual(out.dtype, t) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun