mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Primitive's VJP takes outputs as input (#475)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
		| @@ -71,3 +71,5 @@ class MLXTestCase(unittest.TestCase): | ||||
|         elif not isinstance(expected, mx.array): | ||||
|             expected = mx.array(expected) | ||||
|             self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol)) | ||||
|         else: | ||||
|             self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol)) | ||||
|   | ||||
| @@ -1005,7 +1005,8 @@ class TestArray(mlx_tests.MLXTestCase): | ||||
|         index_y = mx.array([3, 3, 1, 2]) | ||||
|         u = mx.random.uniform(shape=(4,)) | ||||
|         a = a.at[index_x, index_y].add(u) | ||||
|         self.assertEqual(a.sum().item(), u.sum().item()) | ||||
|         self.assertTrue(mx.allclose(a.sum(), u.sum())) | ||||
|         self.assertEqualArray(a.sum(), u.sum(), atol=1e-6, rtol=1e-5) | ||||
|         self.assertEqual(a[index_x, index_y].tolist(), u.tolist()) | ||||
|  | ||||
|         # Test all array.at ops | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun