checkpoint module's __call__

This commit is contained in:
Awni Hannun 2024-03-05 08:39:25 -08:00
parent cbefd9129e
commit 8918a437bb
2 changed files with 26 additions and 12 deletions

View File

@ -37,25 +37,23 @@ def value_and_grad(model: Module, fn: Callable):
return wrapped_value_grad_fn
def checkpoint(module: Module, fn: Callable = None): def checkpoint(module: Module):
"""Transform the passed callable to one that performs gradient """Transform the passed module to one that performs gradient
checkpointing with respect to the trainable parameters of the module (and checkpointing.
the callable's inputs).
The checkpointing is with respect to the module's trainable parameters and
inputs of the module's ``__call__`` function.
Args:
module (mlx.nn.Module): The module for whose parameters we will be
performing gradient checkpointing.
fn (Callable, optional): The function to checkpoint. If not provided it
defaults to the provided module.
Returns:
A callable that saves the inputs and outputs during the forward pass The module that saves the inputs and outputs during the forward pass
and recomputes all intermediate states during the backward pass.
"""
if fn is None:
# Capturing module instead of module.__call__ allows someone to fn = module.__call__
# monkey-patch __call__ later on and the correct method will be used
fn = module
def inner_fn(params, *args, **kwargs):
module.update(params)
@ -67,4 +65,9 @@ def checkpoint(module: Module, fn: Callable = None):
def wrapped_checkpointed_fn(*args, **kwargs):
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
return wrapped_checkpointed_fn class _(type(module)):
def __call__(self, *arg, **kwarg):
return wrapped_checkpointed_fn(*arg, **kwarg)
module.__class__ = _
return module

View File

@ -1481,5 +1481,16 @@ class TestLayers(mlx_tests.MLXTestCase):
)
class TestNNUtils(mlx_tests.MLXTestCase):
def test_checkpoint(self):
lin = nn.Linear(2, 2)
x = mx.array([0.1, 0.2])
expected_y = lin(x)
y = nn.utils.checkpoint(lin)(x)
self.assertTrue(mx.allclose(expected_y, y))
if __name__ == "__main__":
unittest.main()