mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	checkpoint module's __call__
This commit is contained in:
		@@ -37,25 +37,23 @@ def value_and_grad(model: Module, fn: Callable):
 | 
			
		||||
    return wrapped_value_grad_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def checkpoint(module: Module, fn: Callable = None):
 | 
			
		||||
    """Transform the passed callable to one that performs gradient
 | 
			
		||||
    checkpointing with respect to the trainable parameters of the module (and
 | 
			
		||||
    the callable's inputs).
 | 
			
		||||
def checkpoint(module: Module):
 | 
			
		||||
    """Transform the passed module to one that performs gradient
 | 
			
		||||
    checkpointing.
 | 
			
		||||
 | 
			
		||||
    The checkpointing is with respect to the module's trainable parameters and
 | 
			
		||||
    inputs of the module's ``__call__`` function.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        module (mlx.nn.Module): The module for whose parameters we will be
 | 
			
		||||
            performing gradient checkpointing.
 | 
			
		||||
        fn (Callable, optional): The function to checkpoint. If not provided it
 | 
			
		||||
            defaults to the provided module.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        A callable that saves the inputs and outputs during the forward pass
 | 
			
		||||
        The module that saves the inputs and outputs during the forward pass
 | 
			
		||||
        and recomputes all intermediate states during the backward pass.
 | 
			
		||||
    """
 | 
			
		||||
    if fn is None:
 | 
			
		||||
        # Capturing module instead of module.__call__ allows someone to
 | 
			
		||||
        # monkey-patch __call__ later on and the correct method will be used
 | 
			
		||||
        fn = module
 | 
			
		||||
 | 
			
		||||
    fn = module.__call__
 | 
			
		||||
 | 
			
		||||
    def inner_fn(params, *args, **kwargs):
 | 
			
		||||
        module.update(params)
 | 
			
		||||
@@ -67,4 +65,9 @@ def checkpoint(module: Module, fn: Callable = None):
 | 
			
		||||
    def wrapped_checkpointed_fn(*args, **kwargs):
 | 
			
		||||
        return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    return wrapped_checkpointed_fn
 | 
			
		||||
    class _(type(module)):
 | 
			
		||||
        def __call__(self, *arg, **kwarg):
 | 
			
		||||
            return wrapped_checkpointed_fn(*arg, **kwarg)
 | 
			
		||||
 | 
			
		||||
    module.__class__ = _
 | 
			
		||||
    return module
 | 
			
		||||
 
 | 
			
		||||
@@ -1481,5 +1481,16 @@ class TestLayers(mlx_tests.MLXTestCase):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestNNUtils(mlx_tests.MLXTestCase):
 | 
			
		||||
 | 
			
		||||
    def test_checkpoint(self):
 | 
			
		||||
        lin = nn.Linear(2, 2)
 | 
			
		||||
        x = mx.array([0.1, 0.2])
 | 
			
		||||
 | 
			
		||||
        expected_y = lin(x)
 | 
			
		||||
        y = nn.utils.checkpoint(lin)(x)
 | 
			
		||||
        self.assertTrue(mx.allclose(expected_y, y))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user