mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	checkpoint module's __call__
This commit is contained in:
		@@ -37,25 +37,23 @@ def value_and_grad(model: Module, fn: Callable):
 | 
				
			|||||||
    return wrapped_value_grad_fn
 | 
					    return wrapped_value_grad_fn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def checkpoint(module: Module, fn: Callable = None):
 | 
					def checkpoint(module: Module):
 | 
				
			||||||
    """Transform the passed callable to one that performs gradient
 | 
					    """Transform the passed module to one that performs gradient
 | 
				
			||||||
    checkpointing with respect to the trainable parameters of the module (and
 | 
					    checkpointing.
 | 
				
			||||||
    the callable's inputs).
 | 
					
 | 
				
			||||||
 | 
					    The checkpointing is with respect to the module's trainable parameters and
 | 
				
			||||||
 | 
					    inputs of the module's ``__call__`` function.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Args:
 | 
					    Args:
 | 
				
			||||||
        module (mlx.nn.Module): The module for whose parameters we will be
 | 
					        module (mlx.nn.Module): The module for whose parameters we will be
 | 
				
			||||||
            performing gradient checkpointing.
 | 
					            performing gradient checkpointing.
 | 
				
			||||||
        fn (Callable, optional): The function to checkpoint. If not provided it
 | 
					 | 
				
			||||||
            defaults to the provided module.
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Returns:
 | 
					    Returns:
 | 
				
			||||||
        A callable that saves the inputs and outputs during the forward pass
 | 
					        The module that saves the inputs and outputs during the forward pass
 | 
				
			||||||
        and recomputes all intermediate states during the backward pass.
 | 
					        and recomputes all intermediate states during the backward pass.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    if fn is None:
 | 
					
 | 
				
			||||||
        # Capturing module instead of module.__call__ allows someone to
 | 
					    fn = module.__call__
 | 
				
			||||||
        # monkey-patch __call__ later on and the correct method will be used
 | 
					 | 
				
			||||||
        fn = module
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def inner_fn(params, *args, **kwargs):
 | 
					    def inner_fn(params, *args, **kwargs):
 | 
				
			||||||
        module.update(params)
 | 
					        module.update(params)
 | 
				
			||||||
@@ -67,4 +65,9 @@ def checkpoint(module: Module, fn: Callable = None):
 | 
				
			|||||||
    def wrapped_checkpointed_fn(*args, **kwargs):
 | 
					    def wrapped_checkpointed_fn(*args, **kwargs):
 | 
				
			||||||
        return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
 | 
					        return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return wrapped_checkpointed_fn
 | 
					    class _(type(module)):
 | 
				
			||||||
 | 
					        def __call__(self, *arg, **kwarg):
 | 
				
			||||||
 | 
					            return wrapped_checkpointed_fn(*arg, **kwarg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    module.__class__ = _
 | 
				
			||||||
 | 
					    return module
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1481,5 +1481,16 @@ class TestLayers(mlx_tests.MLXTestCase):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestNNUtils(mlx_tests.MLXTestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_checkpoint(self):
 | 
				
			||||||
 | 
					        lin = nn.Linear(2, 2)
 | 
				
			||||||
 | 
					        x = mx.array([0.1, 0.2])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        expected_y = lin(x)
 | 
				
			||||||
 | 
					        y = nn.utils.checkpoint(lin)(x)
 | 
				
			||||||
 | 
					        self.assertTrue(mx.allclose(expected_y, y))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    unittest.main()
 | 
					    unittest.main()
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user