Mirror of https://github.com/ml-explore/mlx.git, synced 2025-10-31 07:58:14 +08:00
Custom transforms (#1246)

Angelos Katharopoulos, committed by GitHub

parent a3c287354f
commit 5c1fa64fb0
@@ -496,6 +496,90 @@ class TestAutograd(mlx_tests.MLXTestCase):
         expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
         self.assertTrue(mx.allclose(out, expected))
 
+    def test_custom_function(self):
+        # Make a custom function
+        my_exp = mx.custom_function(mx.exp)
+
+        # Ensure everything works
+        dy = mx.grad(my_exp)(mx.array(1.0))
+        self.assertTrue(mx.allclose(dy, mx.exp(mx.array(1.0))))
+        (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
+        self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0))))
+        self.assertTrue(mx.allclose(ex, dex))
+        ex = mx.vmap(my_exp)(mx.ones(10))
+        self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10))))
+
+        # Ensure that the vjp is being overridden but everything else still
+        # works.
+        @my_exp.vjp
+        def my_exp_vjp(x, dx, ex):
+            return mx.ones_like(x) * 42
+
+        dy = mx.grad(my_exp)(mx.array(1.0))
+        self.assertTrue(mx.allclose(dy, mx.array(42.0)))
+        (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
+        self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0))))
+        self.assertTrue(mx.allclose(ex, dex))
+        ex = mx.vmap(my_exp)(mx.ones(10))
+        self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10))))
+
+        # Ensure that setting the jvp and vmap also works.
+        @my_exp.jvp
+        def my_exp_jvp(x, dx):
+            return mx.ones_like(x) * 7 * dx
+
+        @my_exp.vmap
+        def my_exp_vmap(x, axis):
+            return mx.ones_like(x) * 3, axis
+
+        dy = mx.grad(my_exp)(mx.array(1.0))
+        self.assertTrue(mx.allclose(dy, mx.array(42.0)))
+        (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
+        self.assertTrue(mx.allclose(dex, mx.array(7.0)))
+        self.assertTrue(mx.allclose(ex, mx.exp(mx.array(1.0))))
+        ex = mx.vmap(my_exp)(mx.ones(10))
+        self.assertTrue(mx.allclose(ex, 3 * mx.ones(10)))
+
+        # Test pytrees
+        @mx.custom_function
+        def my_double(params):
+            return {"out": 2 * params["x"] * params["y"]}
+
+        dy = mx.grad(lambda p: my_double(p)["out"].sum())(
+            {"x": mx.ones(2), "y": mx.ones(2)}
+        )
+        self.assertTrue(mx.allclose(dy["x"], mx.ones(2) * 2))
+        self.assertTrue(mx.allclose(dy["y"], mx.ones(2) * 2))
+
+        @my_double.vjp
+        def random_grads(primals, cotangents, outputs):
+            return {"x": mx.zeros_like(primals["x"]), "y": mx.ones_like(primals["y"])}
+
+        dy = mx.grad(lambda p: my_double(p)["out"].sum())(
+            {"x": mx.ones(2), "y": mx.ones(2)}
+        )
+        self.assertTrue(mx.allclose(dy["x"], mx.zeros(2)))
+        self.assertTrue(mx.allclose(dy["y"], mx.ones(2)))
+
+        def outer_f(a, b):
+            return my_double({"x": a, "y": b})["out"]
+
+        inputs = [mx.random.normal(shape=(2,)) for i in range(2)]
+        tans = [mx.random.normal(shape=(2,)) for i in range(2)]
+        out1, dout1 = mx.jvp(outer_f, inputs, tans)
+
+        @my_double.jvp
+        def random_grads(primals, tangents):
+            return {
+                "out": 2 * primals["x"] * tangents["y"]
+                + 2 * primals["y"] * tangents["x"]
+                + 1
+            }
+
+        out2, dout2 = mx.jvp(outer_f, inputs, tans)
+        self.assertTrue(mx.allclose(out1[0], out2[0]))
+        self.assertTrue(mx.allclose(dout1[0] + 1, dout2[0]))
+
 
 if __name__ == "__main__":
     unittest.main()
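Read as documentation, the test above fixes the shape of the new API: mx.custom_function wraps a function (or decorates one), and the wrapper exposes .vjp, .jvp, and .vmap decorators that override the corresponding transform rules. Below is a minimal standalone sketch of that pattern; my_square and the constants in its overrides are hypothetical stand-ins, chosen (like the 42/7/3 in the test) to make each override observable.

import mlx.core as mx

# Hedged sketch of the mx.custom_function API exercised by the test above.
# my_square and its override constants are hypothetical; the decorator
# names (.vjp, .jvp, .vmap) and their signatures come from the diff.

@mx.custom_function
def my_square(x):
    return x * x

@my_square.vjp
def my_square_vjp(primals, cotangents, outputs):
    # Reverse-mode override: mx.grad now returns this instead of 2 * x.
    return mx.ones_like(primals) * 42

@my_square.jvp
def my_square_jvp(primals, tangents):
    # Forward-mode override used by mx.jvp.
    return mx.ones_like(primals) * 7 * tangents

@my_square.vmap
def my_square_vmap(inputs, axes):
    # Batching override used by mx.vmap: return the batched result and
    # the axis it is batched along.
    return mx.ones_like(inputs) * 3, axes

dy = mx.grad(my_square)(mx.array(2.0))   # 42.0, from the vjp override
(out,), (dout,) = mx.jvp(my_square, [mx.array(2.0)], [mx.array(1.0)])
# out is the ordinary primal output (4.0); dout is 7.0, from the jvp override
ex = mx.vmap(my_square)(mx.ones(4))      # [3, 3, 3, 3], from the vmap override

As in the test, an override registered after the wrapper is created applies to subsequent mx.grad, mx.jvp, and mx.vmap calls on the same object, and any rule left unset falls back to the ordinary behavior of the wrapped function.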
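The pytree portion of the test generalizes this: the wrapped function may take and return nests of containers, and a vjp override then returns cotangents with the same structure as the primals. A hedged sketch mirroring my_double, with hypothetical names:

import mlx.core as mx

# Hypothetical pytree example: a dict of arrays in, a dict of arrays out.
@mx.custom_function
def scale_sum(params):
    return {"out": params["a"] + 2 * params["b"]}

@scale_sum.vjp
def scale_sum_vjp(primals, cotangents, outputs):
    # Cotangents are returned in the same structure as the primal dict.
    return {
        "a": mx.zeros_like(primals["a"]),     # pretend "a" gets no gradient
        "b": 2 * mx.ones_like(primals["b"]),  # the true gradient for "b"
    }

grads = mx.grad(lambda p: scale_sum(p)["out"].sum())(
    {"a": mx.ones(3), "b": mx.ones(3)}
)
# grads["a"] == zeros(3) and grads["b"] == 2 * ones(3), per the override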