mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
This commit is contained in:
		| @@ -190,6 +190,117 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         n_enable_compiled = count_prims(cfun(x)) | ||||
|         self.assertEqual(n_compiled, n_enable_compiled) | ||||
|  | ||||
|     def test_compile_two_input_grad(self): | ||||
|         def loss(w, x): | ||||
|             y = x * w | ||||
|             return (y * mx.exp(y)).sum() | ||||
|  | ||||
|         x = mx.array([1.0, 0.5, 2.0, -0.5]) | ||||
|         w = mx.array([-1.0, 0.3, 1.0, -0.9]) | ||||
|  | ||||
|         expected_grad = mx.grad(loss)(w, x) | ||||
|         compiled_grad = mx.compile(mx.grad(loss))(w, x) | ||||
|         self.assertTrue(mx.allclose(expected_grad, compiled_grad)) | ||||
|  | ||||
|     def test_vmap_compiled(self): | ||||
|         def simple_unary(x): | ||||
|             return -mx.exp(x) | ||||
|  | ||||
|         x = mx.array([[1.0, 2.0], [2.0, 3.0]]) | ||||
|  | ||||
|         expected_out = mx.vmap(simple_unary)(x) | ||||
|         out = mx.vmap(mx.compile(simple_unary))(x) | ||||
|         self.assertTrue(mx.allclose(expected_out, out)) | ||||
|  | ||||
|         def simple_binary(x, y): | ||||
|             return mx.abs(mx.exp(x + y) + y) | ||||
|  | ||||
|         x = mx.array([[1.0, -3.0], [0.5, -0.5]]) | ||||
|         y = mx.array([[2.0, -1.0], [0.25, -0.25]]) | ||||
|  | ||||
|         expected_out = mx.vmap(simple_binary)(x, y) | ||||
|         out = mx.vmap(mx.compile(simple_binary))(x, y) | ||||
|         self.assertTrue(mx.allclose(expected_out, out)) | ||||
|  | ||||
|         expected_out = mx.vmap(simple_binary, in_axes=(0, 1))(x, y) | ||||
|         out = mx.vmap(mx.compile(simple_binary), in_axes=(0, 1))(x, y) | ||||
|         self.assertTrue(mx.allclose(expected_out, out)) | ||||
|  | ||||
|         y = mx.array([0.25, -0.25]) | ||||
|         expected_out = mx.vmap(simple_binary, in_axes=(0, None))(x, y) | ||||
|         out = mx.vmap(mx.compile(simple_binary), in_axes=(0, None))(x, y) | ||||
|         self.assertTrue(mx.allclose(expected_out, out)) | ||||
|  | ||||
|         def simple_unary_outer(x): | ||||
|             x = mx.abs(x) | ||||
|  | ||||
|             @mx.compile | ||||
|             def simple_unary_inner(z): | ||||
|                 return -mx.exp(x) | ||||
|  | ||||
|             return simple_unary_inner(x) | ||||
|  | ||||
|         expected_out = -mx.exp(mx.abs(x)) | ||||
|         out = mx.vmap(simple_unary_outer)(x) | ||||
|         self.assertTrue(mx.allclose(expected_out, out)) | ||||
|  | ||||
|     def test_vjp_vjp_compiled(self): | ||||
|         def simple_unary(x): | ||||
|             return -mx.exp(x) | ||||
|  | ||||
|         x = mx.array([[1.0, 2.0], [2.0, 3.0]]) | ||||
|         y = mx.array([[1.0, 1.0], [1.0, 1.0]]) | ||||
|  | ||||
|         expected_out, expected_vjp_out = mx.vjp(simple_unary, (x,), (y,)) | ||||
|         out, vjp_out = mx.vjp(mx.compile(simple_unary), (x,), (y,)) | ||||
|         self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0])) | ||||
|         self.assertTrue(mx.allclose(expected_out[0], out[0])) | ||||
|  | ||||
|         expected_out, expected_jvp_out = mx.jvp(simple_unary, (x,), (y,)) | ||||
|         out, jvp_out = mx.jvp(mx.compile(simple_unary), (x,), (y,)) | ||||
|         self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0])) | ||||
|         self.assertTrue(mx.allclose(expected_out[0], out[0])) | ||||
|  | ||||
|         def simple_binary(x, y): | ||||
|             return mx.abs(mx.exp(x + y) + y) | ||||
|  | ||||
|         x = mx.array([[1.0, -3.0], [0.5, -0.5]]) | ||||
|         y = mx.array([[2.0, -1.0], [0.25, -0.25]]) | ||||
|         cotans = mx.ones_like(x) | ||||
|  | ||||
|         expected_out, expected_vjp_out = mx.vjp(simple_binary, (x, y), (cotans,)) | ||||
|         out, vjp_out = mx.vjp(mx.compile(simple_binary), (x, y), (cotans,)) | ||||
|         self.assertTrue(mx.allclose(expected_out[0], out[0])) | ||||
|         self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0])) | ||||
|         self.assertTrue(mx.allclose(expected_vjp_out[1], vjp_out[1])) | ||||
|  | ||||
|         tans = (mx.ones_like(x), mx.ones_like(y)) | ||||
|         expected_out, expected_jvp_out = mx.jvp(simple_binary, (x, y), tans) | ||||
|         out, jvp_out = mx.jvp(mx.compile(simple_binary), (x, y), tans) | ||||
|         self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0])) | ||||
|         self.assertTrue(mx.allclose(expected_out[0], out[0])) | ||||
|  | ||||
|     def test_transform_over_eval_compiled(self): | ||||
|         def outer(x): | ||||
|             y = mx.exp(mx.abs(x)) | ||||
|             mx.eval(y) | ||||
|             return y.sum() | ||||
|  | ||||
|         x = mx.array([2.0, -1.0, 0.5]) | ||||
|         dfdx = mx.grad(outer)(x) | ||||
|  | ||||
|         @mx.compile | ||||
|         def simple_unary(x): | ||||
|             return mx.exp(mx.abs(x)) | ||||
|  | ||||
|         def outer(x): | ||||
|             y = simple_unary(x) | ||||
|             mx.eval(y) | ||||
|             return y.sum() | ||||
|  | ||||
|         cdfdx = mx.grad(outer)(x) | ||||
|         self.assertTrue(mx.allclose(dfdx, cdfdx)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun