Compile with capture (#629)

* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-02-07 17:29:22 -08:00
committed by GitHub
parent e5e816a5ef
commit 1b97b2958b
13 changed files with 723 additions and 157 deletions

View File

@@ -2,6 +2,7 @@
import io
import unittest
from functools import partial
import mlx.core as mx
import mlx_tests
@@ -301,6 +302,85 @@ class TestCompile(mlx_tests.MLXTestCase):
cdfdx = mx.grad(outer)(x)
self.assertTrue(mx.allclose(dfdx, cdfdx))
def test_compile_capture(self):
# Test update captured state outside compiled function
state = {"y": mx.array(2)}
@partial(mx.compile, inputs=state)
def test_state(x):
x = x + state["y"]
return x
test_state(mx.array(1))
# Check the state is unchanged
self.assertEqual(state["y"], 2)
# Check the udpated state is used
state["y"] = mx.array(3)
out = test_state(mx.array(1))
self.assertEqual(out.item(), 4)
# Capture list
state = [mx.array(2)]
@partial(mx.compile, inputs=state)
def test_state(x):
x = x + state[0]
return x
out = test_state(mx.array(1))
self.assertEqual(out.item(), 3)
state[0] = mx.array(3)
out = test_state(mx.array(1))
self.assertEqual(out.item(), 4)
# Capture tuple of list
state = ([mx.array(2)],)
@partial(mx.compile, inputs=state)
def test_state(x):
x = x + state[0][0]
return x
out = test_state(mx.array(1))
self.assertEqual(out.item(), 3)
state[0][0] = mx.array(3)
out = test_state(mx.array(1))
self.assertEqual(out.item(), 4)
# Test state updated inside compiled function
state = {}
@partial(mx.compile, outputs=state)
def test_state(x):
state["y"] = x + 3
return mx.abs(x)
test_state(mx.array(-1))
self.assertEqual(state["y"].item(), 2)
# Test state changed inside compiled function
# triggers recompile
state = {}
@partial(mx.compile, inputs=state, outputs=state)
def test_state(x):
y = state.get("y", mx.array(0))
state["y"] = x + y
return x + 2 * y
test_state(mx.array(1))
self.assertEqual(state["y"].item(), 1)
test_state(mx.array(1))
self.assertEqual(state["y"].item(), 2)
def test_compile_rng(self):
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def fun():
return mx.random.uniform(shape=(10, 10))
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
if __name__ == "__main__":
unittest.main()