mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Add scatter_min VJP (#462)
This commit is contained in:
		| @@ -316,6 +316,29 @@ class TestAutograd(mlx_tests.MLXTestCase): | ||||
|         self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]]))) | ||||
|         self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]]))) | ||||
|  | ||||
|     def test_scatter_min_vjp(self): | ||||
|         def fun(src, updates): | ||||
|             x = src.at[1].minimum(updates) | ||||
|             return x | ||||
|  | ||||
|         cotan = mx.array([4.0, 5.0, 6.0]) | ||||
|         _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0]), mx.array([[3.0]])], [cotan]) | ||||
|         mx.eval(vjps) | ||||
|  | ||||
|         # Update larger than value | ||||
|         self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0]))) | ||||
|         self.assertTrue(mx.allclose(vjps[1], mx.array([0.0]))) | ||||
|  | ||||
|         cotan = mx.array([[4.0], [5.0], [6.0]]) | ||||
|         _, vjps = mx.vjp( | ||||
|             fun, [mx.array([[1.0], [2.0], [3.0]]), mx.array([[[2.0]]])], [cotan] | ||||
|         ) | ||||
|         mx.eval(vjps) | ||||
|  | ||||
|         # Update and value are equal | ||||
|         self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]]))) | ||||
|         self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]]))) | ||||
|  | ||||
|     def test_vjp_types(self): | ||||
|         def fun(x): | ||||
|             return x | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Tristan Bilot
					Tristan Bilot