diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py
index f651ce92e..7035450e4 100644
--- a/python/mlx/nn/utils.py
+++ b/python/mlx/nn/utils.py
@@ -37,25 +37,23 @@ def value_and_grad(model: Module, fn: Callable):
     return wrapped_value_grad_fn
 
 
-def checkpoint(module: Module, fn: Callable = None):
-    """Transform the passed callable to one that performs gradient
-    checkpointing with respect to the trainable parameters of the module (and
-    the callable's inputs).
+def checkpoint(module: Module):
+    """Transform the passed module to one that performs gradient
+    checkpointing.
+
+    The checkpointing is with respect to the module's trainable parameters and
+    inputs of the module's ``__call__`` function.
 
     Args:
         module (mlx.nn.Module): The module for whose parameters we will be
             performing gradient checkpointing.
-        fn (Callable, optional): The function to checkpoint. If not provided it
-            defaults to the provided module.
 
     Returns:
-        A callable that saves the inputs and outputs during the forward pass
+        The module that saves the inputs and outputs during the forward pass
         and recomputes all intermediate states during the backward pass.
     """
-    if fn is None:
-        # Capturing module instead of module.__call__ allows someone to
-        # monkey-patch __call__ later on and the correct method will be used
-        fn = module
+
+    fn = module.__call__
 
     def inner_fn(params, *args, **kwargs):
         module.update(params)
@@ -67,4 +65,9 @@ def checkpoint(module: Module, fn: Callable = None):
     def wrapped_checkpointed_fn(*args, **kwargs):
         return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
 
-    return wrapped_checkpointed_fn
+    class _(type(module)):
+        def __call__(self, *arg, **kwarg):
+            return wrapped_checkpointed_fn(*arg, **kwarg)
+
+    module.__class__ = _
+    return module
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 678acfd5b..43cd55177 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -1481,5 +1481,16 @@ class TestLayers(mlx_tests.MLXTestCase):
         )
 
 
+class TestNNUtils(mlx_tests.MLXTestCase):
+
+    def test_checkpoint(self):
+        lin = nn.Linear(2, 2)
+        x = mx.array([0.1, 0.2])
+
+        expected_y = lin(x)
+        y = nn.utils.checkpoint(lin)(x)
+        self.assertTrue(mx.allclose(expected_y, y))
+
+
 if __name__ == "__main__":
     unittest.main()
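
For context, a minimal usage sketch of the new API. It is illustrative only and not part of the patch; the loss function, shapes, and variable names below are assumptions.

    import mlx.core as mx
    import mlx.nn as nn

    lin = nn.Linear(2, 2)
    x = mx.array([0.1, 0.2])

    # With this patch, checkpoint() mutates and returns the module itself:
    # its __call__ saves only inputs/outputs on the forward pass and
    # recomputes intermediate state during the backward pass.
    lin = nn.utils.checkpoint(lin)

    # Hypothetical loss used only to exercise the backward pass.
    def loss_fn(model, x):
        return mx.sum(model(x))

    loss, grads = nn.value_and_grad(lin, loss_fn)(lin, x)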