checkpoint module's __call__

This commit is contained in:
Awni Hannun 2024-03-05 08:39:25 -08:00
parent cbefd9129e
commit 8918a437bb
2 changed files with 26 additions and 12 deletions

View File

@ -37,25 +37,23 @@ def value_and_grad(model: Module, fn: Callable):
return wrapped_value_grad_fn
def checkpoint(module: Module, fn: Callable = None):
"""Transform the passed callable to one that performs gradient
checkpointing with respect to the trainable parameters of the module (and
the callable's inputs).
def checkpoint(module: Module):
"""Transform the passed module to one that performs gradient
checkpointing.
The checkpointing is with respect to the module's trainable parameters and
inputs of the module's ``__call__`` function.
Args:
module (mlx.nn.Module): The module for whose parameters we will be
performing gradient checkpointing.
fn (Callable, optional): The function to checkpoint. If not provided it
defaults to the provided module.
Returns:
A callable that saves the inputs and outputs during the forward pass
The module that saves the inputs and outputs during the forward pass
and recomputes all intermediate states during the backward pass.
"""
if fn is None:
# Capturing module instead of module.__call__ allows someone to
# monkey-patch __call__ later on and the correct method will be used
fn = module
fn = module.__call__
def inner_fn(params, *args, **kwargs):
module.update(params)
@ -67,4 +65,9 @@ def checkpoint(module: Module, fn: Callable = None):
def wrapped_checkpointed_fn(*args, **kwargs):
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
return wrapped_checkpointed_fn
class _(type(module)):
def __call__(self, *arg, **kwarg):
return wrapped_checkpointed_fn(*arg, **kwarg)
module.__class__ = _
return module

View File

@ -1481,5 +1481,16 @@ class TestLayers(mlx_tests.MLXTestCase):
)
class TestNNUtils(mlx_tests.MLXTestCase):
    # Tests exercising ``nn.utils.checkpoint``.

    def test_checkpoint(self):
        # A checkpointed linear layer must produce the same forward output
        # as the plain layer did for the same input.
        layer = nn.Linear(2, 2)
        inputs = mx.array([0.1, 0.2])

        # Take the reference output first, before the layer is passed to
        # ``checkpoint``.
        reference = layer(inputs)

        checkpointed_layer = nn.utils.checkpoint(layer)
        result = checkpointed_layer(inputs)

        outputs_match = mx.allclose(reference, result)
        self.assertTrue(outputs_match)
# Run the test suite through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()