Compare commits: batch_rope...checkpoint (4 commits)
Repository: https://github.com/ml-explore/mlx.git
Commits: 0dbe80a024, a5827d0384, 1368bce280, 8918a437bb
@@ -173,6 +173,7 @@ In detail:
    :toctree: _autosummary

    value_and_grad
+   checkpoint

 .. toctree::
@@ -17,3 +17,4 @@ Transforms
    jvp
    vjp
    vmap
+   checkpoint
@@ -2,4 +2,4 @@
 from mlx.nn import init, losses
 from mlx.nn.layers import *
-from mlx.nn.utils import value_and_grad
+from mlx.nn.utils import checkpoint, value_and_grad
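
If this hunk is the mlx.nn package __init__, the change re-exports checkpoint alongside value_and_grad, so both spellings below should refer to the same function. A minimal check under that assumption:

    import mlx.nn as nn

    # Both names should resolve to the same callable after the import change.
    print(nn.checkpoint is nn.utils.checkpoint)          # expected: True
    print(nn.value_and_grad is nn.utils.value_and_grad)  # expected: True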
@@ -137,7 +137,10 @@ class Module(dict):
         super(Module, self).__getattribute__(key)

     def __setattr__(self, key: str, val: Any):
-        if isinstance(val, (mx.array, dict, list, tuple)):
+        # Allow setter properties to pass through to base class
+        prop = vars(self.__class__).get(key, None)
+        is_prop = isinstance(prop, property) and prop.fset is not None
+        if not is_prop and isinstance(val, (mx.array, dict, list, tuple)):
             self[key] = val
         else:
             super(Module, self).__setattr__(key, val)
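
With the new check, an assignment to an attribute backed by a property setter now reaches that setter even when the assigned value is an mx.array; previously such values were stored directly in the module's parameter dict, bypassing the setter. A minimal sketch of the behavior this enables; the Scaled module and its scale property are invented here for illustration only:

    import mlx.core as mx
    import mlx.nn as nn

    class Scaled(nn.Module):
        def __init__(self):
            super().__init__()
            self._scale = mx.array(1.0)

        @property
        def scale(self):
            return self._scale

        @scale.setter
        def scale(self, val):
            # With the hunk above, this setter runs even though val is an
            # mx.array; before, the array would have been stored under the
            # "scale" key directly and the setter skipped.
            self._scale = 2 * val

    m = Scaled()
    m.scale = mx.array(3.0)
    print(m.scale)  # expected: 6.0 once the setter has run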
@@ -15,7 +15,7 @@ def value_and_grad(model: Module, fn: Callable):

     Args:
         model (mlx.nn.Module): The model whose trainable parameters to compute
-        gradients for
+            gradients for
         fn (Callable): The scalar function to compute gradients for

     Returns:
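
For context, a small usage sketch of mlx.nn.value_and_grad as described by this docstring; the linear model and mean-squared-error loss below are arbitrary choices:

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(2, 1)

    def loss_fn(x, y):
        # A scalar function; the model is captured from the enclosing scope
        # and differentiated through its trainable parameters.
        return nn.losses.mse_loss(model(x), y, reduction="mean")

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    x = mx.random.normal((4, 2))
    y = mx.random.normal((4, 1))
    loss, grads = loss_and_grad_fn(x, y)
    # grads mirrors the tree structure of model.trainable_parameters()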
@@ -37,34 +37,46 @@ def value_and_grad(model: Module, fn: Callable):
     return wrapped_value_grad_fn


-def checkpoint(module: Module, fn: Callable = None):
-    """Transform the passed callable to one that performs gradient
-    checkpointing with respect to the trainable parameters of the module (and
-    the callable's inputs).
+def checkpoint(module: Module):
+    """Transform the passed module to one that performs gradient checkpointing.
+
+    The checkpointing is with respect to the module's trainable parameters and
+    inputs of the module's ``__call__`` function.

     Args:
-        module (mlx.nn.Module): The module for whose parameters we will be
-            performing gradient checkpointing.
-        fn (Callable, optional): The function to checkpoint. If not provided it
-            defaults to the provided module.
+        module (mlx.nn.Module): The module for which we will perform gradient
+            checkpointing.

     Returns:
-        A callable that saves the inputs and outputs during the forward pass
+        A new module that saves the inputs and outputs during the forward pass
         and recomputes all intermediate states during the backward pass.
     """
-    if fn is None:
-        # Capturing module instead of module.__call__ allows someone to
-        # monkey-patch __call__ later on and the correct method will be used
-        fn = module
+
+    t = type(module)
+    cp_name = f"__checkpointed_{id(t)}__"
+    cp_class = globals().get(cp_name, None)
+    if cp_class is None:
+
+        def init(self):
+            pass
+
+        def call(self, *args, **kwargs):
+            return self.__checkpointed_call__(
+                self.trainable_parameters(), *args, **kwargs
+            )
+
+        cp_class = type(t.__name__, (t,), {})
+        cp_class.__init__ = init
+        cp_class.__call__ = call
+        globals()[cp_name] = cp_class
+
+    cp_module = cp_class()
+    cp_module.__dict__.update(module.__dict__)
+    super(Module, cp_module).update(module.state)

     def inner_fn(params, *args, **kwargs):
         module.update(params)
-        return fn(*args, **kwargs)
+        return module(*args, **kwargs)

-    checkpointed_fn = mx.checkpoint(inner_fn)
-
-    @wraps(fn)
-    def wrapped_checkpointed_fn(*args, **kwargs):
-        return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
-
-    return wrapped_checkpointed_fn
+    cp_module.__checkpointed_call__ = mx.checkpoint(inner_fn)
+    return cp_module
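
A sketch of how the rewritten transform might be used; checkpoint now returns a drop-in module rather than a bare callable, so it composes with nn.value_and_grad the same way the original module does:

    import mlx.core as mx
    import mlx.nn as nn

    layer = nn.Linear(32, 32)
    c_layer = nn.utils.checkpoint(layer)  # still behaves as an nn.Linear

    def loss_fn(x):
        # Intermediates of c_layer's forward pass are recomputed during the
        # backward pass instead of being stored.
        return c_layer(x).sum()

    loss_and_grad_fn = nn.value_and_grad(c_layer, loss_fn)

    x = mx.random.normal((16, 32))
    loss, grads = loss_and_grad_fn(x)

Caching the generated subclass in globals() under a name derived from id(t) means only one wrapper class is created per wrapped type, and copying module.__dict__ plus the state keeps the original module's attributes and repr visible on the wrapper, which the new test below relies on.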
@@ -888,7 +888,24 @@ void init_transforms(py::module_& m) {
   m.def(
       "checkpoint",
       [](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
-      "fun"_a);
+      "fun"_a,
+      R"pbdoc(
+        checkpoint(fun: function) -> function
+
+        Returns a gradient checkpointed function.
+
+        The checkpointed function produces the same output as the input
+        ``fun`` but recomputes all intermediate states during the gradient
+        computation (vjp) rather than storing them.
+
+        Use the checkpoint transformation to reduce memory consumption at the
+        cost of increased computation.
+
+        Args:
+            fun (function): A function which takes a variable number of
+              :class:`array` or trees of :class:`array` and returns
+              a variable number of :class:`array` or trees of :class:`array`.
+      )pbdoc");

   // Register static Python object cleanup before the interpreter exits
   auto atexit = py::module_::import("atexit");
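
A small illustration of the core transform documented above; the function is arbitrary, and the two gradients should match while the checkpointed version trades recomputation for lower memory use:

    import mlx.core as mx

    def block(x):
        # Elementwise work whose intermediates we would rather recompute
        # during the vjp than keep alive.
        return mx.exp(mx.sin(x)).sum()

    ckpt_block = mx.checkpoint(block)

    x = mx.random.normal((1024,))
    g_ref = mx.grad(block)(x)
    g_ckpt = mx.grad(ckpt_block)(x)
    print(mx.allclose(g_ref, g_ckpt))  # expected: True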
@@ -1481,5 +1481,28 @@ class TestLayers(mlx_tests.MLXTestCase):
         )


+class TestNNUtils(mlx_tests.MLXTestCase):
+
+    def test_checkpoint(self):
+        lin = nn.Linear(2, 2)
+        x = mx.array([0.1, 0.2])
+
+        lin.my_attr = "hello"
+
+        expected_y = lin(x)
+        clin = nn.utils.checkpoint(lin)
+        y = clin(x)
+        self.assertTrue(mx.allclose(expected_y, y))
+
+        # Check get/set attribute
+        self.assertEqual(clin.my_attr, "hello")
+
+        clin.my_attr = "bye"
+        self.assertEqual(clin.my_attr, "bye")
+
+        self.assertTrue(isinstance(clin, nn.Linear))
+        self.assertEqual(repr(clin), repr(lin))
+
+
 if __name__ == "__main__":
     unittest.main()