mirror of https://github.com/ml-explore/mlx.git
	Compile with capture (#629)
* Simple kernel generation
* Remove the generate kernel from graph_utils
* fix multi-output with compile
* fuse with stopgrad
* v1 input, output capture in compile
* cleanup tree update with visitor update
* nit
* remove todo
* state for model, optional explicit init and more pure optimizer steps
* move learning rate to state
* add lr to opt state, some fixes in capture
* fix optim
* update tuple of containers as well
* fix stream for compiled output
* rng state for compile
* nit
* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
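The headline change is that mx.compile can now capture implicit state through its inputs and outputs arguments, so a compiled function no longer has to be purely functional in its arguments. A minimal sketch of the usage, distilled from the new test_compile_capture test below (the state dict and its values are illustrative):

    from functools import partial

    import mlx.core as mx

    # Implicit state that the compiled function both reads and writes.
    state = {"y": mx.array(2)}

    @partial(mx.compile, inputs=state, outputs=state)
    def step(x):
        # Reads the captured input state and writes a new value back into it.
        state["y"] = x + state["y"]
        return 2 * x

    out = step(mx.array(1))
    print(out.item())         # 2
    print(state["y"].item())  # 3: the write inside the compiled call is visible here

As the tests exercise, values updated outside the function are picked up on the next call rather than being baked in as trace-time constants, and changing the captured state inside the function triggers a recompile.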
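The optimizer side of the change (explicit init, the learning rate kept in the optimizer state, and purer update steps) also allows a fully explicit compiled training step in which the optimizer state is threaded through as an argument. A sketch adapted from the new test_compiled_optimizer test below:

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as opt

    model = nn.Linear(10, 10)
    optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
    optim.init(model.parameters())  # explicit init allocates the momentum buffers up front

    def loss(params, x):
        model.update(params)
        return model(x).sum()

    @mx.compile
    def step(params, opt_state, x):
        # Optimizer state goes in and comes out explicitly, so the compiled
        # function stays a pure mapping from inputs to outputs.
        grad = mx.grad(loss)(params, x)
        optim.state = opt_state
        params = optim.apply_gradients(grad, params)
        return params, optim.state

    x = mx.random.uniform(shape=(2, 10))
    new_params, new_opt_state = step(model.parameters(), optim.state, x)

The tests also cover the impure variant, where [model.state, optim.state] is passed to mx.compile as both inputs and outputs and the step simply mutates the model and optimizer in place.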
		| @@ -2,6 +2,7 @@ | ||||
|  | ||||
| import io | ||||
| import unittest | ||||
| from functools import partial | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx_tests | ||||
| @@ -301,6 +302,85 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         cdfdx = mx.grad(outer)(x) | ||||
|         self.assertTrue(mx.allclose(dfdx, cdfdx)) | ||||
|  | ||||
|     def test_compile_capture(self): | ||||
|         # Test update captured state outside compiled function | ||||
|         state = {"y": mx.array(2)} | ||||
|  | ||||
|         @partial(mx.compile, inputs=state) | ||||
|         def test_state(x): | ||||
|             x = x + state["y"] | ||||
|             return x | ||||
|  | ||||
|         test_state(mx.array(1)) | ||||
|         # Check the state is unchanged | ||||
|         self.assertEqual(state["y"], 2) | ||||
|  | ||||
|         # Check the updated state is used | ||||
|         state["y"] = mx.array(3) | ||||
|         out = test_state(mx.array(1)) | ||||
|         self.assertEqual(out.item(), 4) | ||||
|  | ||||
|         # Capture list | ||||
|         state = [mx.array(2)] | ||||
|  | ||||
|         @partial(mx.compile, inputs=state) | ||||
|         def test_state(x): | ||||
|             x = x + state[0] | ||||
|             return x | ||||
|  | ||||
|         out = test_state(mx.array(1)) | ||||
|         self.assertEqual(out.item(), 3) | ||||
|         state[0] = mx.array(3) | ||||
|         out = test_state(mx.array(1)) | ||||
|         self.assertEqual(out.item(), 4) | ||||
|  | ||||
|         # Capture tuple of list | ||||
|         state = ([mx.array(2)],) | ||||
|  | ||||
|         @partial(mx.compile, inputs=state) | ||||
|         def test_state(x): | ||||
|             x = x + state[0][0] | ||||
|             return x | ||||
|  | ||||
|         out = test_state(mx.array(1)) | ||||
|         self.assertEqual(out.item(), 3) | ||||
|         state[0][0] = mx.array(3) | ||||
|         out = test_state(mx.array(1)) | ||||
|         self.assertEqual(out.item(), 4) | ||||
|  | ||||
|         # Test state updated inside compiled function | ||||
|         state = {} | ||||
|  | ||||
|         @partial(mx.compile, outputs=state) | ||||
|         def test_state(x): | ||||
|             state["y"] = x + 3 | ||||
|             return mx.abs(x) | ||||
|  | ||||
|         test_state(mx.array(-1)) | ||||
|         self.assertEqual(state["y"].item(), 2) | ||||
|  | ||||
|         # Test state changed inside compiled function | ||||
|         # triggers recompile | ||||
|         state = {} | ||||
|  | ||||
|         @partial(mx.compile, inputs=state, outputs=state) | ||||
|         def test_state(x): | ||||
|             y = state.get("y", mx.array(0)) | ||||
|             state["y"] = x + y | ||||
|             return x + 2 * y | ||||
|  | ||||
|         test_state(mx.array(1)) | ||||
|         self.assertEqual(state["y"].item(), 1) | ||||
|         test_state(mx.array(1)) | ||||
|         self.assertEqual(state["y"].item(), 2) | ||||
|  | ||||
|     def test_compile_rng(self): | ||||
|         @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) | ||||
|         def fun(): | ||||
|             return mx.random.uniform(shape=(10, 10)) | ||||
|  | ||||
|         self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -24,6 +24,14 @@ class TestEval(mlx_tests.MLXTestCase): | ||||
|         y = dfun_dx(mx.array(1.0)) | ||||
|         self.assertEqual(y.item(), 6.0) | ||||
|  | ||||
|     def test_eval_mixed(self): | ||||
|         x = mx.array(1) + 1 + 1 | ||||
|         y = 0 | ||||
|         z = "hello" | ||||
|         state = [x, y, z] | ||||
|         mx.eval(state) | ||||
|         self.assertEqual(x.item(), 3) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -130,6 +130,11 @@ class TestBase(mlx_tests.MLXTestCase): | ||||
|                 ] | ||||
|             ) | ||||
|  | ||||
|     def test_module_state(self): | ||||
|         m = nn.Linear(10, 1) | ||||
|         m.state["hello"] = "world" | ||||
|         self.assertEqual(m.state["hello"], "world") | ||||
|  | ||||
|  | ||||
| class TestLayers(mlx_tests.MLXTestCase): | ||||
|     def test_identity(self): | ||||
|   | ||||
| @@ -2,47 +2,209 @@ | ||||
|  | ||||
| import inspect | ||||
| import unittest | ||||
| from functools import partial | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import mlx.optimizers as opt | ||||
| import mlx.utils | ||||
| import mlx_tests | ||||
| from mlx.utils import tree_flatten, tree_map | ||||
|  | ||||
|  | ||||
| def get_all_optimizers(): | ||||
|     classes = dict() | ||||
|     for name, obj in inspect.getmembers(opt): | ||||
|         if inspect.isclass(obj): | ||||
|             if obj.__name__ not in ["OptimizerState", "Optimizer"]: | ||||
|             if obj.__name__ not in ["Optimizer"]: | ||||
|                 classes[name] = obj | ||||
|     return classes | ||||
|  | ||||
|  | ||||
| def tree_equal(fn, *args): | ||||
|     return all(v for _, v in tree_flatten(tree_map(fn, *args))) | ||||
|  | ||||
|  | ||||
| optimizers_dict = get_all_optimizers() | ||||
|  | ||||
|  | ||||
| class TestOptimizers(mlx_tests.MLXTestCase): | ||||
|     def test_optimizer_state(self): | ||||
|         optim = opt.SGD(0.1) | ||||
|         optim.state["hello"] = "world" | ||||
|         self.assertEqual(optim.state["hello"], "world") | ||||
|  | ||||
|         optim.state = {0: 1} | ||||
|         self.assertEqual(optim.state, {0: 1}) | ||||
|  | ||||
|     def test_optimizers(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params) | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         for optim_class in optimizers_dict.values(): | ||||
|             optim = optim_class(0.1) | ||||
|             update = optim.apply_gradients(grads, params) | ||||
|             mx.eval(update) | ||||
|             equal_shape = mlx.utils.tree_map( | ||||
|                 lambda x, y: x.shape == y.shape, params, update | ||||
|             ) | ||||
|             equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update) | ||||
|             all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape)) | ||||
|             self.assertTrue(all_equal) | ||||
|  | ||||
|     def test_types_conserved(self): | ||||
|         params = {"w": mx.ones((5, 5), mx.float16)} | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|         for optim_class in optimizers_dict.values(): | ||||
|             optim = optim_class(0.1) | ||||
|             update = optim.apply_gradients(grads, params) | ||||
|             self.assertEqual(update["w"].dtype, mx.float16) | ||||
|  | ||||
|     def test_sgd(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         # Explicit init | ||||
|         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||
|         optim.init(params) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||
|                 params, | ||||
|                 optim.state, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         # Implicit init | ||||
|         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||
|         optim.apply_gradients(grads, params) | ||||
|         self.assertTrue( | ||||
|             tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state) | ||||
|         ) | ||||
|  | ||||
|     def test_rmsprop(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         # Explicit init | ||||
|         optim = opt.RMSprop(learning_rate=1e-2) | ||||
|         optim.init(params) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||
|                 params, | ||||
|                 optim.state, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         # Implicit init | ||||
|         alpha = 0.99 | ||||
|         optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha) | ||||
|         optim.apply_gradients(grads, params) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def test_adagrad(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         # Explicit init | ||||
|         optim = opt.Adagrad(learning_rate=1e-2) | ||||
|         optim.init(params) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||
|                 params, | ||||
|                 optim.state, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def test_adadelta(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         # Explicit init | ||||
|         optim = opt.AdaDelta(learning_rate=1e-2) | ||||
|         optim.init(params) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||
|                 params, | ||||
|                 optim.state, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda p, s: mx.array_equal(s["u"], mx.zeros_like(p)), | ||||
|                 params, | ||||
|                 optim.state, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def test_adam(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         # Explicit init | ||||
|         for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]: | ||||
|             optim = optimizer(learning_rate=1e-2) | ||||
|             optim.init(params) | ||||
|             self.assertTrue( | ||||
|                 tree_equal( | ||||
|                     lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||
|                     params, | ||||
|                     optim.state, | ||||
|                 ) | ||||
|             ) | ||||
|             self.assertTrue( | ||||
|                 tree_equal( | ||||
|                     lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)), | ||||
|                     params, | ||||
|                     optim.state, | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
|     def test_lion(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         # Explicit init | ||||
|         optim = opt.Lion(learning_rate=1e-2) | ||||
|         optim.init(params) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)), | ||||
|                 params, | ||||
|                 optim.state, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def test_adafactor(self): | ||||
|         x = mx.zeros((5, 5)) | ||||
|         grad = mx.ones_like(x) | ||||
|         optimizer = opt.Adafactor() | ||||
|         optimizer.init(x) | ||||
|         for _ in range(2): | ||||
|             xp = optimizer.apply_single(grad, x, optimizer.state) | ||||
|             self.assertEqual(xp.dtype, x.dtype) | ||||
| @@ -51,12 +213,86 @@ class TestOptimizers(mlx_tests.MLXTestCase): | ||||
|         x = mx.zeros((5, 5), mx.float16) | ||||
|         grad = mx.ones_like(x) | ||||
|         optimizer = opt.Adafactor() | ||||
|         optimizer.init(x) | ||||
|         for _ in range(2): | ||||
|             xp = optimizer.apply_single(grad, x, optimizer.state) | ||||
|             self.assertEqual(xp.dtype, x.dtype) | ||||
|             self.assertEqual(xp.shape, x.shape) | ||||
|         self.assertEqual(optimizer.state["step"], 2) | ||||
|  | ||||
|     def test_compiled_optimizer(self): | ||||
|         model = nn.Linear(10, 10) | ||||
|         x = mx.random.uniform(shape=(2, 10)) | ||||
|         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||
|  | ||||
|         orig_params = model.parameters() | ||||
|  | ||||
|         def loss(model, x): | ||||
|             return model(x).sum() | ||||
|  | ||||
|         # Uncompiled version | ||||
|         def step(x): | ||||
|             _, grad = nn.value_and_grad(model, loss)(model, x) | ||||
|             optim.update(model, grad) | ||||
|  | ||||
|         step(x) | ||||
|         uncompiled_params = model.parameters() | ||||
|  | ||||
|         # Pure version | ||||
|         def loss(params, x): | ||||
|             model.update(params) | ||||
|             return model(x).sum() | ||||
|  | ||||
|         model.update(orig_params) | ||||
|         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||
|  | ||||
|         @mx.compile | ||||
|         def step(params, opt_state, x): | ||||
|             grad = mx.grad(loss)(params, x) | ||||
|             optim.state = opt_state | ||||
|             params = optim.apply_gradients(grad, params) | ||||
|             return params, optim.state | ||||
|  | ||||
|         optim.init(model.parameters()) | ||||
|         pure_params, _ = step(model.parameters(), optim.state, x) | ||||
|         self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"])) | ||||
|         self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"])) | ||||
|  | ||||
|         # Impure version | ||||
|         def loss(model, x): | ||||
|             return model(x).sum() | ||||
|  | ||||
|         model.update(orig_params) | ||||
|         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||
|         state = [model.state, optim.state] | ||||
|  | ||||
|         @partial(mx.compile, inputs=state, outputs=state) | ||||
|         def step(x): | ||||
|             _, grad = nn.value_and_grad(model, loss)(model, x) | ||||
|             optim.update(model, grad) | ||||
|  | ||||
|         step(x) | ||||
|         impure_params = model.parameters() | ||||
|         self.assertTrue( | ||||
|             mx.allclose(impure_params["weight"], uncompiled_params["weight"]) | ||||
|         ) | ||||
|         self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"])) | ||||
|  | ||||
|     def test_update_lr_compiled(self): | ||||
|         params = {"w": mx.ones((5, 5))} | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|         optim = opt.SGD(-1.0) | ||||
|  | ||||
|         @partial(mx.compile, inputs=optim.state) | ||||
|         def update(grads): | ||||
|             return optim.apply_gradients(grads, params) | ||||
|  | ||||
|         result = update(grads) | ||||
|         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0))) | ||||
|         optim.learning_rate = -2.0 | ||||
|         result = update(grads) | ||||
|         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0))) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
Awni Hannun