mlx/python/tests/test_compile.py
Awni Hannun 1b97b2958b
Compile with capture (#629)
* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-07 17:29:22 -08:00

387 lines
11 KiB
Python

# Copyright © 2023-2024 Apple Inc.
import io
import unittest
from functools import partial
import mlx.core as mx
import mlx_tests
class TestCompile(mlx_tests.MLXTestCase):
    """Tests for ``mx.compile``.

    Covers plain compiled calls, composition with ``grad`` / ``value_and_grad``
    / ``vmap`` / ``vjp`` / ``jvp``, closures over enclosing variables, and the
    explicit ``inputs=`` / ``outputs=`` state capture arguments.
    """

    def test_simple_compile(self):
        """A compiled add gives the expected result across calls, shapes and dtypes."""

        def fun(x, y):
            return x + y

        # NOTE(review): the function is compiled twice in a row; presumably
        # this checks that re-compiling the same function is harmless --
        # confirm, or drop the duplicate.
        compiled_fn = mx.compile(fun)
        compiled_fn = mx.compile(fun)
        x = mx.array(1.0)
        y = mx.array(1.0)
        out = compiled_fn(x, y)
        self.assertEqual(out.item(), 2.0)

        # Try again
        out = compiled_fn(x, y)
        self.assertEqual(out.item(), 2.0)

        # Change sizes
        x = mx.array([1.0, 2.0])
        out = compiled_fn(x, y)
        self.assertTrue(mx.array_equal(out, mx.array([2.0, 3.0])))

        y = mx.array([1.0, 2.0])
        out = compiled_fn(x, y)
        self.assertTrue(mx.array_equal(out, mx.array([2.0, 4.0])))

        # Change types
        x = mx.array([1, 2], mx.int32)
        y = mx.array([1, 2], mx.int32)
        out = compiled_fn(x, y)
        self.assertEqual(out.dtype, mx.int32)
        self.assertTrue(mx.array_equal(out, mx.array([2, 4])))

    def test_compile_grad(self):
        """Compiled ``grad`` and ``value_and_grad`` functions match the eager ones."""

        def loss_fn(x):
            return mx.exp(x).sum()

        grad_fn = mx.grad(loss_fn)

        x = mx.array([0.5, -0.5, 1.2])
        dfdx = grad_fn(x)
        compile_grad_fn = mx.compile(grad_fn)
        # Call the *compiled* function here; calling ``grad_fn`` again would
        # only compare the eager result with itself.
        c_dfdx = compile_grad_fn(x)

        self.assertTrue(mx.allclose(c_dfdx, dfdx))

        # Run it again without calling compile
        c_dfdx = compile_grad_fn(x)
        self.assertTrue(mx.allclose(c_dfdx, dfdx))

        # Run it again with calling compile
        c_dfdx = mx.compile(grad_fn)(x)
        self.assertTrue(mx.allclose(c_dfdx, dfdx))

        # Value and grad
        def loss_fn(x):
            return mx.exp(x).sum(), mx.sin(x)

        val_and_grad_fn = mx.value_and_grad(loss_fn)
        (loss, val), dfdx = val_and_grad_fn(x)
        (c_loss, c_val), c_dfdx = mx.compile(val_and_grad_fn)(x)

        self.assertTrue(mx.allclose(c_dfdx, dfdx))
        self.assertTrue(mx.allclose(c_loss, loss))
        self.assertTrue(mx.allclose(c_val, val))

    def test_compile_inputs_with_primitives(self):
        """Compiled and eager results agree when the inputs come from prior ops."""

        def make_inputs():
            # Build both inputs through a short chain of additions so that
            # the eager and compiled calls receive identically-built arrays.
            x = mx.array([1, 2, 3])
            y = mx.array([1, 2, 3])
            for _ in range(5):
                x = x + y
                y = y + 1
            return x, y

        def fun(x, y):
            return x * y

        out = fun(*make_inputs())

        x, y = make_inputs()
        c_out = mx.compile(fun)(x, y)
        self.assertTrue(mx.array_equal(out, c_out))

        # Try again
        c_out = mx.compile(fun)(x, y)
        self.assertTrue(mx.array_equal(out, c_out))

    def test_compile_with_closure(self):
        """Values enclosed at first call keep being used after the name is rebound."""
        x = mx.array(1)

        def closure(y):
            return x + y

        compiled = mx.compile(closure)
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 2)

        # Try again
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 2)

        # Change the shape of the enclosed variable
        x = mx.array([1, 2])
        out = compiled(mx.array(1))

        # We still get the original input (closures are not updated)
        self.assertEqual(out.item(), 2)

        # Try with a tree of enclosed variables
        x = {"a": mx.array(1), "b": mx.array(2)}

        def closure(y):
            return x["a"] + y + x["b"]

        compiled = mx.compile(closure)
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Change the shape of one input
        x["a"] = mx.array([4, 5])
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 4)

        x["b"] = mx.array([-6, -8])
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Enclosed variable is not evaluated yet
        x = mx.array(1)
        x = x + x

        def closure(y):
            return x + y

        compiled = mx.compile(closure)
        out = compiled(mx.array(2))
        self.assertEqual(out.item(), 4)

        # And again
        out = compiled(mx.array(2))
        self.assertEqual(out.item(), 4)

    def test_function_creates_array(self):
        """A compiled function may construct a new array inside its body."""

        def fun(x):
            return x + mx.array(1)

        cfun = mx.compile(fun)
        out = cfun(mx.array(3))
        self.assertEqual(out.item(), 4)

        # And again
        out = cfun(mx.array(3))
        self.assertEqual(out.item(), 4)

    def test_enable_disable(self):
        """Disabling compile yields a larger exported graph; re-enabling restores it."""

        def fun(x):
            y = x + 1
            z = x + 1
            return y + z

        def count_prims(outputs):
            # Count the whitespace-separated tokens of the DOT export that
            # mention "label"; used as a proxy for the size of the graph.
            buf = io.StringIO()
            mx.export_to_dot(buf, outputs)
            buf.seek(0)
            return len([tok for tok in buf.read().split() if "label" in tok])

        x = mx.array(1.0)
        cfun = mx.compile(fun)
        n_compiled = count_prims(cfun(x))

        # Check disabled
        mx.disable_compile()
        try:
            n_uncompiled = count_prims(cfun(x))
            self.assertLess(n_compiled, n_uncompiled)
        finally:
            # Always re-enable so a failure above cannot leave compilation
            # switched off for the tests that run afterwards.
            mx.enable_compile()

        # Check re-enabled
        n_enable_compiled = count_prims(cfun(x))
        self.assertEqual(n_compiled, n_enable_compiled)

    def test_compile_two_input_grad(self):
        """Compiling the gradient of a two-argument loss matches the eager gradient."""

        def loss(w, x):
            y = x * w
            return (y * mx.exp(y)).sum()

        x = mx.array([1.0, 0.5, 2.0, -0.5])
        w = mx.array([-1.0, 0.3, 1.0, -0.9])

        expected_grad = mx.grad(loss)(w, x)
        compiled_grad = mx.compile(mx.grad(loss))(w, x)
        self.assertTrue(mx.allclose(expected_grad, compiled_grad))

    def test_vmap_compiled(self):
        """``vmap`` over compiled functions matches ``vmap`` over the eager ones."""

        def simple_unary(x):
            return -mx.exp(x)

        x = mx.array([[1.0, 2.0], [2.0, 3.0]])

        expected_out = mx.vmap(simple_unary)(x)
        out = mx.vmap(mx.compile(simple_unary))(x)
        self.assertTrue(mx.allclose(expected_out, out))

        def simple_binary(x, y):
            return mx.abs(mx.exp(x + y) + y)

        x = mx.array([[1.0, -3.0], [0.5, -0.5]])
        y = mx.array([[2.0, -1.0], [0.25, -0.25]])

        expected_out = mx.vmap(simple_binary)(x, y)
        out = mx.vmap(mx.compile(simple_binary))(x, y)
        self.assertTrue(mx.allclose(expected_out, out))

        expected_out = mx.vmap(simple_binary, in_axes=(0, 1))(x, y)
        out = mx.vmap(mx.compile(simple_binary), in_axes=(0, 1))(x, y)
        self.assertTrue(mx.allclose(expected_out, out))

        y = mx.array([0.25, -0.25])
        expected_out = mx.vmap(simple_binary, in_axes=(0, None))(x, y)
        out = mx.vmap(mx.compile(simple_binary), in_axes=(0, None))(x, y)
        self.assertTrue(mx.allclose(expected_out, out))

        def simple_unary_outer(x):
            x = mx.abs(x)

            # NOTE(review): the inner function ignores its parameter ``z``
            # and reads the enclosing ``x`` instead. The call below passes
            # that same ``x``, so the result is unaffected; confirm whether
            # the closure access is deliberate.
            @mx.compile
            def simple_unary_inner(z):
                return -mx.exp(x)

            return simple_unary_inner(x)

        expected_out = -mx.exp(mx.abs(x))
        out = mx.vmap(simple_unary_outer)(x)
        self.assertTrue(mx.allclose(expected_out, out))

    def test_vjp_vjp_compiled(self):
        """``vjp`` and ``jvp`` through compiled functions match the eager results."""

        def simple_unary(x):
            return -mx.exp(x)

        x = mx.array([[1.0, 2.0], [2.0, 3.0]])
        y = mx.array([[1.0, 1.0], [1.0, 1.0]])

        expected_out, expected_vjp_out = mx.vjp(simple_unary, (x,), (y,))
        out, vjp_out = mx.vjp(mx.compile(simple_unary), (x,), (y,))
        self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0]))
        self.assertTrue(mx.allclose(expected_out[0], out[0]))

        expected_out, expected_jvp_out = mx.jvp(simple_unary, (x,), (y,))
        out, jvp_out = mx.jvp(mx.compile(simple_unary), (x,), (y,))
        self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0]))
        self.assertTrue(mx.allclose(expected_out[0], out[0]))

        def simple_binary(x, y):
            return mx.abs(mx.exp(x + y) + y)

        x = mx.array([[1.0, -3.0], [0.5, -0.5]])
        y = mx.array([[2.0, -1.0], [0.25, -0.25]])
        cotans = mx.ones_like(x)

        expected_out, expected_vjp_out = mx.vjp(simple_binary, (x, y), (cotans,))
        out, vjp_out = mx.vjp(mx.compile(simple_binary), (x, y), (cotans,))
        self.assertTrue(mx.allclose(expected_out[0], out[0]))
        self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0]))
        self.assertTrue(mx.allclose(expected_vjp_out[1], vjp_out[1]))

        tans = (mx.ones_like(x), mx.ones_like(y))
        expected_out, expected_jvp_out = mx.jvp(simple_binary, (x, y), tans)
        out, jvp_out = mx.jvp(mx.compile(simple_binary), (x, y), tans)
        self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0]))
        self.assertTrue(mx.allclose(expected_out[0], out[0]))

    def test_transform_over_eval_compiled(self):
        """``grad`` over a function that evals a compiled result matches eager."""

        def outer(x):
            y = mx.exp(mx.abs(x))
            mx.eval(y)
            return y.sum()

        x = mx.array([2.0, -1.0, 0.5])
        dfdx = mx.grad(outer)(x)

        @mx.compile
        def simple_unary(x):
            return mx.exp(mx.abs(x))

        # Same computation as above, but the inner expression is compiled
        def outer(x):
            y = simple_unary(x)
            mx.eval(y)
            return y.sum()

        cdfdx = mx.grad(outer)(x)
        self.assertTrue(mx.allclose(dfdx, cdfdx))

    def test_compile_capture(self):
        """State passed through ``inputs=`` / ``outputs=`` is read and written."""
        # Test update captured state outside compiled function
        state = {"y": mx.array(2)}

        @partial(mx.compile, inputs=state)
        def test_state(x):
            x = x + state["y"]
            return x

        test_state(mx.array(1))
        # Check the state is unchanged
        self.assertEqual(state["y"].item(), 2)

        # Check the updated state is used
        state["y"] = mx.array(3)
        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Capture list
        state = [mx.array(2)]

        @partial(mx.compile, inputs=state)
        def test_state(x):
            x = x + state[0]
            return x

        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 3)
        state[0] = mx.array(3)
        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Capture tuple of list
        state = ([mx.array(2)],)

        @partial(mx.compile, inputs=state)
        def test_state(x):
            x = x + state[0][0]
            return x

        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 3)
        state[0][0] = mx.array(3)
        out = test_state(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Test state updated inside compiled function
        state = {}

        @partial(mx.compile, outputs=state)
        def test_state(x):
            state["y"] = x + 3
            return mx.abs(x)

        test_state(mx.array(-1))
        self.assertEqual(state["y"].item(), 2)

        # Test state changed inside compiled function
        # triggers recompile
        state = {}

        @partial(mx.compile, inputs=state, outputs=state)
        def test_state(x):
            y = state.get("y", mx.array(0))
            state["y"] = x + y
            return x + 2 * y

        test_state(mx.array(1))
        self.assertEqual(state["y"].item(), 1)
        test_state(mx.array(1))
        self.assertEqual(state["y"].item(), 2)

    def test_compile_rng(self):
        """With the RNG state captured, two successive calls give different samples."""

        @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
        def fun():
            return mx.random.uniform(shape=(10, 10))

        self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()