mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	fix unflatten vjp (#1708)
This commit is contained in:
		| @@ -605,6 +605,21 @@ class TestAutograd(mlx_tests.MLXTestCase): | ||||
|         dfdx = mx.grad(fun)(x) | ||||
|         self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x))) | ||||
|  | ||||
|     def test_flatten_unflatten_vjps(self): | ||||
|         def fun(x): | ||||
|             y = mx.unflatten(x, 0, (2, 2)) | ||||
|             return y.sum() | ||||
|  | ||||
|         x = mx.zeros((4, 8)) | ||||
|         self.assertEqual(mx.grad(fun)(x).shape, (4, 8)) | ||||
|  | ||||
|         def fun(x): | ||||
|             y = mx.flatten(x, 0, 2) | ||||
|             return y.sum() | ||||
|  | ||||
|         x = mx.zeros((2, 4, 8)) | ||||
|         self.assertEqual(mx.grad(fun)(x).shape, (2, 4, 8)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun