mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-27 20:07:59 +08:00
checkpoint module's __call__
This commit is contained in:
parent
cbefd9129e
commit
8918a437bb
@ -37,25 +37,23 @@ def value_and_grad(model: Module, fn: Callable):
|
|||||||
return wrapped_value_grad_fn
|
return wrapped_value_grad_fn
|
||||||
|
|
||||||
|
|
||||||
def checkpoint(module: Module, fn: Callable = None):
|
def checkpoint(module: Module):
|
||||||
"""Transform the passed callable to one that performs gradient
|
"""Transform the passed module to one that performs gradient
|
||||||
checkpointing with respect to the trainable parameters of the module (and
|
checkpointing.
|
||||||
the callable's inputs).
|
|
||||||
|
The checkpointing is with respect to the module's trainable parameters and
|
||||||
|
inputs of the module's ``__call__`` function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
module (mlx.nn.Module): The module for whose parameters we will be
|
module (mlx.nn.Module): The module for whose parameters we will be
|
||||||
performing gradient checkpointing.
|
performing gradient checkpointing.
|
||||||
fn (Callable, optional): The function to checkpoint. If not provided it
|
|
||||||
defaults to the provided module.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A callable that saves the inputs and outputs during the forward pass
|
The module that saves the inputs and outputs during the forward pass
|
||||||
and recomputes all intermediate states during the backward pass.
|
and recomputes all intermediate states during the backward pass.
|
||||||
"""
|
"""
|
||||||
if fn is None:
|
|
||||||
# Capturing module instead of module.__call__ allows someone to
|
fn = module.__call__
|
||||||
# monkey-patch __call__ later on and the correct method will be used
|
|
||||||
fn = module
|
|
||||||
|
|
||||||
def inner_fn(params, *args, **kwargs):
|
def inner_fn(params, *args, **kwargs):
|
||||||
module.update(params)
|
module.update(params)
|
||||||
@ -67,4 +65,9 @@ def checkpoint(module: Module, fn: Callable = None):
|
|||||||
def wrapped_checkpointed_fn(*args, **kwargs):
|
def wrapped_checkpointed_fn(*args, **kwargs):
|
||||||
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
|
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
|
||||||
|
|
||||||
return wrapped_checkpointed_fn
|
class _(type(module)):
|
||||||
|
def __call__(self, *arg, **kwarg):
|
||||||
|
return wrapped_checkpointed_fn(*arg, **kwarg)
|
||||||
|
|
||||||
|
module.__class__ = _
|
||||||
|
return module
|
||||||
|
@ -1481,5 +1481,16 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNNUtils(mlx_tests.MLXTestCase):
|
||||||
|
|
||||||
|
def test_checkpoint(self):
|
||||||
|
lin = nn.Linear(2, 2)
|
||||||
|
x = mx.array([0.1, 0.2])
|
||||||
|
|
||||||
|
expected_y = lin(x)
|
||||||
|
y = nn.utils.checkpoint(lin)(x)
|
||||||
|
self.assertTrue(mx.allclose(expected_y, y))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user