Compare commits

...

4 Commits

Author       SHA1        Message                                   Date
Awni Hannun  0dbe80a024  try again with checkpointed classes       2024-03-06 10:38:04 -08:00
Awni Hannun  a5827d0384  docs for checkpoint + a few more tests    2024-03-06 10:38:04 -08:00
Awni Hannun  1368bce280  fix tests and add setter attributes       2024-03-06 10:38:04 -08:00
Awni Hannun  8918a437bb  checkpoint module's __call__              2024-03-06 10:38:04 -08:00
7 changed files with 82 additions and 25 deletions

View File

@@ -173,6 +173,7 @@ In detail:
    :toctree: _autosummary
 
    value_and_grad
+   checkpoint
 
 .. toctree::

View File

@@ -17,3 +17,4 @@ Transforms
    jvp
    vjp
    vmap
+   checkpoint

View File

@@ -2,4 +2,4 @@
 from mlx.nn import init, losses
 from mlx.nn.layers import *
-from mlx.nn.utils import value_and_grad
+from mlx.nn.utils import checkpoint, value_and_grad

View File

@@ -137,7 +137,10 @@ class Module(dict):
             super(Module, self).__getattribute__(key)
 
     def __setattr__(self, key: str, val: Any):
-        if isinstance(val, (mx.array, dict, list, tuple)):
+        # Allow setter properties to pass through to base class
+        prop = vars(self.__class__).get(key, None)
+        is_prop = isinstance(prop, property) and prop.fset is not None
+        if not is_prop and isinstance(val, (mx.array, dict, list, tuple)):
             self[key] = val
         else:
             super(Module, self).__setattr__(key, val)
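
A minimal sketch of what the new is_prop check enables; the Scaled class and its scale property below are hypothetical and purely illustrative, not part of this changeset:

    import mlx.core as mx
    import mlx.nn as nn

    class Scaled(nn.Linear):
        # Hypothetical subclass: scale is a property with a setter.
        @property
        def scale(self):
            return self._scale

        @scale.setter
        def scale(self, value):
            # The backing attribute is an mx.array, so it is stored in the
            # module's parameter dict as usual.
            self._scale = mx.array(value)

    layer = Scaled(2, 2)
    # Assigning an array to scale now reaches the property setter; previously
    # the array would have been stored directly under the key "scale" in the
    # module dict, bypassing the setter.
    layer.scale = mx.array(3.0)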

View File

@@ -15,7 +15,7 @@ def value_and_grad(model: Module, fn: Callable):
     Args:
         model (mlx.nn.Module): The model whose trainable parameters to compute
-            gradients for
+            gradients for
         fn (Callable): The scalar function to compute gradients for
 
     Returns:
@@ -37,34 +37,46 @@ def value_and_grad(model: Module, fn: Callable):
     return wrapped_value_grad_fn
 
 
-def checkpoint(module: Module, fn: Callable = None):
-    """Transform the passed callable to one that performs gradient
-    checkpointing with respect to the trainable parameters of the module (and
-    the callable's inputs).
+def checkpoint(module: Module):
+    """Transform the passed module to one that performs gradient checkpointing.
+
+    The checkpointing is with respect to the module's trainable parameters and
+    inputs of the module's ``__call__`` function.
 
     Args:
-        module (mlx.nn.Module): The module for whose parameters we will be
-            performing gradient checkpointing.
-        fn (Callable, optional): The function to checkpoint. If not provided it
-            defaults to the provided module.
+        module (mlx.nn.Module): The module for which we will perform gradient
+            checkpointing.
 
     Returns:
-        A callable that saves the inputs and outputs during the forward pass
+        A new module that saves the inputs and outputs during the forward pass
         and recomputes all intermediate states during the backward pass.
     """
-    if fn is None:
-        # Capturing module instead of module.__call__ allows someone to
-        # monkey-patch __call__ later on and the correct method will be used
-        fn = module
+    t = type(module)
+    cp_name = f"__checkpointed_{id(t)}__"
+    cp_class = globals().get(cp_name, None)
+    if cp_class is None:
+
+        def init(self):
+            pass
+
+        def call(self, *args, **kwargs):
+            return self.__checkpointed_call__(
+                self.trainable_parameters(), *args, **kwargs
+            )
+
+        cp_class = type(t.__name__, (t,), {})
+        cp_class.__init__ = init
+        cp_class.__call__ = call
+        globals()[cp_name] = cp_class
+
+    cp_module = cp_class()
+    cp_module.__dict__.update(module.__dict__)
+    super(Module, cp_module).update(module.state)
 
     def inner_fn(params, *args, **kwargs):
         module.update(params)
-        return fn(*args, **kwargs)
+        return module(*args, **kwargs)
 
-    checkpointed_fn = mx.checkpoint(inner_fn)
-
-    @wraps(fn)
-    def wrapped_checkpointed_fn(*args, **kwargs):
-        return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
-
-    return wrapped_checkpointed_fn
+    cp_module.__checkpointed_call__ = mx.checkpoint(inner_fn)
+    return cp_module
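
For context, a possible end-to-end use of the new transform; the model, loss, and data below are illustrative placeholders, not taken from this changeset. Wrapping a module trades memory for compute: intermediate activations of its __call__ are recomputed during the backward pass instead of being stored.

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))

    # The returned object is an instance of a dynamically created subclass of
    # type(model), so it can be used wherever the original module was used.
    model = nn.utils.checkpoint(model)

    def loss_fn(model, x):
        return model(x).sum()

    x = mx.random.normal((4, 16))
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, x)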

View File

@@ -888,7 +888,24 @@ void init_transforms(py::module_& m) {
   m.def(
       "checkpoint",
       [](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
-      "fun"_a);
+      "fun"_a,
+      R"pbdoc(
+        checkpoint(fun: function) -> function
+
+        Returns a gradient checkpointed function.
+
+        The checkpointed function produces the same output as the input
+        ``fun`` but recomputes all intermediate states during the gradient
+        computation (vjp) rather than storing them.
+
+        Use the checkpoint transformation to reduce memory consumption at the
+        cost of increased computation.
+
+        Args:
+            fun (function): A function which takes a variable number of
+              :class:`array` or trees of :class:`array` and returns
+              a variable number of :class:`array` or trees of :class:`array`.
+      )pbdoc");
 
   // Register static Python object cleanup before the interpreter exits
   auto atexit = py::module_::import("atexit");
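
A small sketch of the mx.checkpoint transform the docstring above describes; the function f and the shapes are made up for the example. The checkpointed function returns the same values and gradients as the original; only the memory/compute trade-off during the vjp changes.

    import mlx.core as mx

    def f(x, y):
        return (mx.exp(x) * mx.sin(y)).sum()

    cf = mx.checkpoint(f)  # intermediates are recomputed during the vjp

    x = mx.random.normal((32, 32))
    y = mx.random.normal((32, 32))

    # Same value and same gradients as the un-checkpointed function.
    assert mx.allclose(f(x, y), cf(x, y))
    assert mx.allclose(mx.grad(f)(x, y), mx.grad(cf)(x, y))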

View File

@@ -1481,5 +1481,28 @@ class TestLayers(mlx_tests.MLXTestCase):
         )
 
 
+class TestNNUtils(mlx_tests.MLXTestCase):
+    def test_checkpoint(self):
+        lin = nn.Linear(2, 2)
+        x = mx.array([0.1, 0.2])
+        lin.my_attr = "hello"
+
+        expected_y = lin(x)
+
+        clin = nn.utils.checkpoint(lin)
+        y = clin(x)
+        self.assertTrue(mx.allclose(expected_y, y))
+
+        # Check get/set attribute
+        self.assertEqual(clin.my_attr, "hello")
+        clin.my_attr = "bye"
+        self.assertEqual(clin.my_attr, "bye")
+
+        self.assertTrue(isinstance(clin, nn.Linear))
+        self.assertEqual(repr(clin), repr(lin))
+
+
 if __name__ == "__main__":
     unittest.main()
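
A natural follow-on check, sketched here rather than taken from the test above: gradients computed through the checkpointed module should match those of the original module, since checkpointing changes only what is stored versus recomputed.

    import mlx.core as mx
    import mlx.nn as nn

    lin = nn.Linear(2, 2)
    clin = nn.utils.checkpoint(lin)
    x = mx.array([0.1, 0.2])

    def loss_fn(model, x):
        return model(x).sum()

    _, grads = nn.value_and_grad(lin, loss_fn)(lin, x)
    _, cgrads = nn.value_and_grad(clin, loss_fn)(clin, x)

    # Checkpointing should not change the computed gradients.
    assert mx.allclose(grads["weight"], cgrads["weight"])
    assert mx.allclose(grads["bias"], cgrads["bias"])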