Compile with capture (#629)
* Simple kernel generation
* Remove the generate kernel from graph_utils
* fix multi-output with compile
* fuse with stopgrad
* v1 input, output capture in compile
* cleanup tree update with visitor update
* nit
* remove todo
* state for model, optional explicit init and more pure optimizer steps
* move learning rate to state
* add lr to opt state, some fixes in capture
* fix optim
* update tuple of containers as well
* fix stream for compiled output
* rng state for compile
* nit
* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
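The diff below adds tests for compiling functions with captured state. As a rough sketch of the usage pattern those tests exercise (the `state` dict and the `step` function are illustrative names, not library API; the `inputs`/`outputs` keyword arguments to `mx.compile` are the ones the tests use):

from functools import partial

import mlx.core as mx

# Arrays captured from the enclosing scope are declared via `inputs` so the
# compiled function re-reads them on every call, and `outputs` lets updates
# made inside the function propagate back out.
state = {"y": mx.array(2)}

@partial(mx.compile, inputs=state, outputs=state)
def step(x):
    y = state["y"]
    state["y"] = x + y  # written back through `outputs`
    return x + 2 * y

step(mx.array(1))         # reads state["y"] == 2, leaves state["y"] == 3
state["y"] = mx.array(0)  # external updates are picked up through `inputs`
step(mx.array(1))         # now reads the new value
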
@@ -2,6 +2,7 @@

import io
import unittest
from functools import partial

import mlx.core as mx
import mlx_tests
@@ -301,6 +302,85 @@ class TestCompile(mlx_tests.MLXTestCase):
        cdfdx = mx.grad(outer)(x)
        self.assertTrue(mx.allclose(dfdx, cdfdx))

    def test_compile_capture(self):
        # Test update captured state outside compiled function
        state = {"y": mx.array(2)}

        @partial(mx.compile, inputs=state)
        def test_state(x):
            x = x + state["y"]
            return x

        test_state(mx.array(1))
        # Check the state is unchanged
        self.assertEqual(state["y"], 2)

        # Check the updated state is used
state["y"] = mx.array(3)
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 4)
|
||||
|
||||
# Capture list
|
||||
state = [mx.array(2)]
|
||||
|
||||
@partial(mx.compile, inputs=state)
|
||||
def test_state(x):
|
||||
x = x + state[0]
|
||||
return x
|
||||
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 3)
|
||||
state[0] = mx.array(3)
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 4)
|
||||
|
||||
# Capture tuple of list
|
||||
state = ([mx.array(2)],)
|
||||
|
||||
@partial(mx.compile, inputs=state)
|
||||
def test_state(x):
|
||||
x = x + state[0][0]
|
||||
return x
|
||||
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 3)
|
||||
state[0][0] = mx.array(3)
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 4)
|
||||
|
||||
# Test state updated inside compiled function
|
||||
state = {}
|
||||
|
||||
@partial(mx.compile, outputs=state)
|
||||
def test_state(x):
|
||||
state["y"] = x + 3
|
||||
return mx.abs(x)
|
||||
|
||||
test_state(mx.array(-1))
|
||||
self.assertEqual(state["y"].item(), 2)
|
||||
|
||||
# Test state changed inside compiled function
|
||||
# triggers recompile
|
||||
state = {}
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def test_state(x):
|
||||
y = state.get("y", mx.array(0))
|
||||
state["y"] = x + y
|
||||
return x + 2 * y
|
||||
|
||||
test_state(mx.array(1))
|
||||
self.assertEqual(state["y"].item(), 1)
|
||||
test_state(mx.array(1))
|
||||
self.assertEqual(state["y"].item(), 2)
|
||||
|
||||
def test_compile_rng(self):
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def fun():
|
||||
return mx.random.uniform(shape=(10, 10))
|
||||
|
||||
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -24,6 +24,14 @@ class TestEval(mlx_tests.MLXTestCase):
        y = dfun_dx(mx.array(1.0))
        self.assertEqual(y.item(), 6.0)

    def test_eval_mixed(self):
        x = mx.array(1) + 1 + 1
        y = 0
        z = "hello"
        state = [x, y, z]
        mx.eval(state)
        self.assertEqual(x.item(), 3)


if __name__ == "__main__":
    unittest.main()

@@ -130,6 +130,11 @@ class TestBase(mlx_tests.MLXTestCase):
            ]
        )

    def test_module_state(self):
        m = nn.Linear(10, 1)
        m.state["hello"] = "world"
        self.assertEqual(m.state["hello"], "world")


class TestLayers(mlx_tests.MLXTestCase):
    def test_identity(self):

@@ -2,47 +2,209 @@

import inspect
import unittest
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
import mlx.utils
import mlx_tests
from mlx.utils import tree_flatten, tree_map


def get_all_optimizers():
    classes = dict()
    for name, obj in inspect.getmembers(opt):
        if inspect.isclass(obj):
            if obj.__name__ not in ["OptimizerState", "Optimizer"]:
            if obj.__name__ not in ["Optimizer"]:
                classes[name] = obj
    return classes


def tree_equal(fn, *args):
    return all(v for _, v in tree_flatten(tree_map(fn, *args)))


optimizers_dict = get_all_optimizers()


class TestOptimizers(mlx_tests.MLXTestCase):
    def test_optimizer_state(self):
        optim = opt.SGD(0.1)
        optim.state["hello"] = "world"
        self.assertEqual(optim.state["hello"], "world")

        optim.state = {0: 1}
        self.assertEqual(optim.state, {0: 1})

    def test_optimizers(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)
        grads = tree_map(lambda x: mx.ones_like(x), params)

        for optim_class in optimizers_dict.values():
            optim = optim_class(0.1)
            update = optim.apply_gradients(grads, params)
            mx.eval(update)
            equal_shape = mlx.utils.tree_map(
                lambda x, y: x.shape == y.shape, params, update
            )
            equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update)
            all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
            self.assertTrue(all_equal)

    def test_types_conserved(self):
        params = {"w": mx.ones((5, 5), mx.float16)}
        grads = tree_map(lambda x: mx.ones_like(x), params)
        for optim_class in optimizers_dict.values():
            optim = optim_class(0.1)
            update = optim.apply_gradients(grads, params)
            self.assertEqual(update["w"].dtype, mx.float16)

    def test_sgd(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

        # Implicit init
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
        optim.apply_gradients(grads, params)
        self.assertTrue(
            tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state)
        )

    def test_rmsprop(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.RMSprop(learning_rate=1e-2)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

        # Implicit init
        alpha = 0.99
        optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha)
        optim.apply_gradients(grads, params)
        self.assertTrue(
            tree_equal(
                lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state
            )
        )

    def test_adagrad(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.Adagrad(learning_rate=1e-2)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

    def test_adadelta(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.AdaDelta(learning_rate=1e-2)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["u"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

    def test_adam(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]:
            optim = optimizer(learning_rate=1e-2)
            optim.init(params)
            self.assertTrue(
                tree_equal(
                    lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
                    params,
                    optim.state,
                )
            )
            self.assertTrue(
                tree_equal(
                    lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
                    params,
                    optim.state,
                )
            )

    def test_lion(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = tree_map(lambda x: mx.ones_like(x), params)

        # Explicit init
        optim = opt.Lion(learning_rate=1e-2)
        optim.init(params)
        self.assertTrue(
            tree_equal(
                lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
                params,
                optim.state,
            )
        )

    def test_adafactor(self):
        x = mx.zeros((5, 5))
        grad = mx.ones_like(x)
        optimizer = opt.Adafactor()
        optimizer.init(x)
        for _ in range(2):
            xp = optimizer.apply_single(grad, x, optimizer.state)
        self.assertEqual(xp.dtype, x.dtype)
@@ -51,12 +213,86 @@ class TestOptimizers(mlx_tests.MLXTestCase):
        x = mx.zeros((5, 5), mx.float16)
        grad = mx.ones_like(x)
        optimizer = opt.Adafactor()
        optimizer.init(x)
        for _ in range(2):
            xp = optimizer.apply_single(grad, x, optimizer.state)
        self.assertEqual(xp.dtype, x.dtype)
        self.assertEqual(xp.shape, x.shape)
        self.assertEqual(optimizer.state["step"], 2)

    def test_compiled_optimizer(self):
        model = nn.Linear(10, 10)
        x = mx.random.uniform(shape=(2, 10))
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)

        orig_params = model.parameters()

        def loss(model, x):
            return model(x).sum()

        # Uncompiled version
        def step(x):
            _, grad = nn.value_and_grad(model, loss)(model, x)
            optim.update(model, grad)

        step(x)
        uncompiled_params = model.parameters()

        # Pure version
        def loss(params, x):
            model.update(params)
            return model(x).sum()

        model.update(orig_params)
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)

        @mx.compile
        def step(params, opt_state, x):
            grad = mx.grad(loss)(params, x)
            optim.state = opt_state
            params = optim.apply_gradients(grad, params)
            return params, optim.state

        optim.init(model.parameters())
        pure_params, _ = step(model.parameters(), optim.state, x)
        self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"]))
        self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"]))

        # Impure version
        def loss(model, x):
            return model(x).sum()

        model.update(orig_params)
        optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
        state = [model.state, optim.state]

        @partial(mx.compile, inputs=state, outputs=state)
        def step(x):
            _, grad = nn.value_and_grad(model, loss)(model, x)
            optim.update(model, grad)

        step(x)
        impure_params = model.parameters()
        self.assertTrue(
            mx.allclose(impure_params["weight"], uncompiled_params["weight"])
        )
        self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"]))

    def test_update_lr_compiled(self):
        params = {"w": mx.ones((5, 5))}
        grads = tree_map(lambda x: mx.ones_like(x), params)
        optim = opt.SGD(-1.0)

        @partial(mx.compile, inputs=optim.state)
        def update(grads):
            return optim.apply_gradients(grads, params)

        result = update(grads)
        self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0)))
        optim.learning_rate = -2.0
        result = update(grads)
        self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))


if __name__ == "__main__":
    unittest.main()