mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Compile with capture (#629)
* Simple kernel generation * Remove the generate kernel from graph_utils * fix multi-output with compile * fuse with stopgrad * v1 input, output capture in compile * cleanup tree update with visitor update * nit * remove todo * state for model, optional explicit init and more pure optimizer steps * move learning rate to state * add lr to opt state, some fixes in capture * fix optim * update tuple of containers as well * fix stream for compiled output * rng state for compile * nit * updates and comments --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import io
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
@@ -301,6 +302,85 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
cdfdx = mx.grad(outer)(x)
|
||||
self.assertTrue(mx.allclose(dfdx, cdfdx))
|
||||
|
||||
def test_compile_capture(self):
|
||||
# Test update captured state outside compiled function
|
||||
state = {"y": mx.array(2)}
|
||||
|
||||
@partial(mx.compile, inputs=state)
|
||||
def test_state(x):
|
||||
x = x + state["y"]
|
||||
return x
|
||||
|
||||
test_state(mx.array(1))
|
||||
# Check the state is unchanged
|
||||
self.assertEqual(state["y"], 2)
|
||||
|
||||
# Check the udpated state is used
|
||||
state["y"] = mx.array(3)
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 4)
|
||||
|
||||
# Capture list
|
||||
state = [mx.array(2)]
|
||||
|
||||
@partial(mx.compile, inputs=state)
|
||||
def test_state(x):
|
||||
x = x + state[0]
|
||||
return x
|
||||
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 3)
|
||||
state[0] = mx.array(3)
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 4)
|
||||
|
||||
# Capture tuple of list
|
||||
state = ([mx.array(2)],)
|
||||
|
||||
@partial(mx.compile, inputs=state)
|
||||
def test_state(x):
|
||||
x = x + state[0][0]
|
||||
return x
|
||||
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 3)
|
||||
state[0][0] = mx.array(3)
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 4)
|
||||
|
||||
# Test state updated inside compiled function
|
||||
state = {}
|
||||
|
||||
@partial(mx.compile, outputs=state)
|
||||
def test_state(x):
|
||||
state["y"] = x + 3
|
||||
return mx.abs(x)
|
||||
|
||||
test_state(mx.array(-1))
|
||||
self.assertEqual(state["y"].item(), 2)
|
||||
|
||||
# Test state changed inside compiled function
|
||||
# triggers recompile
|
||||
state = {}
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def test_state(x):
|
||||
y = state.get("y", mx.array(0))
|
||||
state["y"] = x + y
|
||||
return x + 2 * y
|
||||
|
||||
test_state(mx.array(1))
|
||||
self.assertEqual(state["y"].item(), 1)
|
||||
test_state(mx.array(1))
|
||||
self.assertEqual(state["y"].item(), 2)
|
||||
|
||||
def test_compile_rng(self):
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def fun():
|
||||
return mx.random.uniform(shape=(10, 10))
|
||||
|
||||
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user