mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-25 03:31:17 +08:00
checkpoint module's __call__
This commit is contained in:
parent
cbefd9129e
commit
8918a437bb
@ -37,25 +37,23 @@ def value_and_grad(model: Module, fn: Callable):
|
||||
return wrapped_value_grad_fn
|
||||
|
||||
|
||||
def checkpoint(module: Module, fn: Callable = None):
|
||||
"""Transform the passed callable to one that performs gradient
|
||||
checkpointing with respect to the trainable parameters of the module (and
|
||||
the callable's inputs).
|
||||
def checkpoint(module: Module):
|
||||
"""Transform the passed module to one that performs gradient
|
||||
checkpointing.
|
||||
|
||||
The checkpointing is with respect to the module's trainable parameters and
|
||||
inputs of the module's ``__call__`` function.
|
||||
|
||||
Args:
|
||||
module (mlx.nn.Module): The module for whose parameters we will be
|
||||
performing gradient checkpointing.
|
||||
fn (Callable, optional): The function to checkpoint. If not provided it
|
||||
defaults to the provided module.
|
||||
|
||||
Returns:
|
||||
A callable that saves the inputs and outputs during the forward pass
|
||||
The module that saves the inputs and outputs during the forward pass
|
||||
and recomputes all intermediate states during the backward pass.
|
||||
"""
|
||||
if fn is None:
|
||||
# Capturing module instead of module.__call__ allows someone to
|
||||
# monkey-patch __call__ later on and the correct method will be used
|
||||
fn = module
|
||||
|
||||
fn = module.__call__
|
||||
|
||||
def inner_fn(params, *args, **kwargs):
|
||||
module.update(params)
|
||||
@ -67,4 +65,9 @@ def checkpoint(module: Module, fn: Callable = None):
|
||||
def wrapped_checkpointed_fn(*args, **kwargs):
|
||||
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
|
||||
|
||||
return wrapped_checkpointed_fn
|
||||
class _(type(module)):
|
||||
def __call__(self, *arg, **kwarg):
|
||||
return wrapped_checkpointed_fn(*arg, **kwarg)
|
||||
|
||||
module.__class__ = _
|
||||
return module
|
||||
|
@ -1481,5 +1481,16 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestNNUtils(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_checkpoint(self):
|
||||
lin = nn.Linear(2, 2)
|
||||
x = mx.array([0.1, 0.2])
|
||||
|
||||
expected_y = lin(x)
|
||||
y = nn.utils.checkpoint(lin)(x)
|
||||
self.assertTrue(mx.allclose(expected_y, y))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user