diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst
index 2a253ab25..e3507f9d3 100644
--- a/docs/src/python/nn.rst
+++ b/docs/src/python/nn.rst
@@ -173,6 +173,7 @@ In detail:
    :toctree: _autosummary
 
    value_and_grad
+   checkpoint
 
 .. toctree::
diff --git a/docs/src/python/transforms.rst b/docs/src/python/transforms.rst
index ad9ba579b..36c2270f9 100644
--- a/docs/src/python/transforms.rst
+++ b/docs/src/python/transforms.rst
@@ -17,3 +17,4 @@ Transforms
    jvp
    vjp
    vmap
+   checkpoint
diff --git a/python/mlx/nn/__init__.py b/python/mlx/nn/__init__.py
index b2cb9e0f4..1a6b65d58 100644
--- a/python/mlx/nn/__init__.py
+++ b/python/mlx/nn/__init__.py
@@ -2,4 +2,4 @@
 
 from mlx.nn import init, losses
 from mlx.nn.layers import *
-from mlx.nn.utils import value_and_grad
+from mlx.nn.utils import checkpoint, value_and_grad
diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py
index 07ab21705..ebd340599 100644
--- a/python/mlx/nn/utils.py
+++ b/python/mlx/nn/utils.py
@@ -15,7 +15,7 @@ def value_and_grad(model: Module, fn: Callable):
 
     Args:
         model (mlx.nn.Module): The model whose trainable parameters to compute
-          gradients for 
+          gradients for
         fn (Callable): The scalar function to compute gradients for
 
     Returns:
@@ -38,15 +38,14 @@ def value_and_grad(model: Module, fn: Callable):
 
 
 def checkpoint(module: Module):
-    """Transform the passed module to one that performs gradient
-    checkpointing.
+    """Transform the passed module to one that performs gradient checkpointing.
 
     The checkpointing is with respect to the module's trainable parameters and
     inputs of the module's ``__call__`` function.
 
     Args:
-        module (mlx.nn.Module): The module for whose parameters we will be
-            performing gradient checkpointing.
+        module (mlx.nn.Module): The module for which we will perform gradient
+            checkpointing.
 
     Returns:
         The module that saves the inputs and outputs during the forward pass
diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp
index 1612f774d..f560ae501 100644
--- a/python/src/transforms.cpp
+++ b/python/src/transforms.cpp
@@ -888,7 +888,24 @@ void init_transforms(py::module_& m) {
   m.def(
       "checkpoint",
       [](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
-      "fun"_a);
+      "fun"_a,
+      R"pbdoc(
+        checkpoint(fun: function) -> function
+
+        Returns a gradient checkpointed function.
+
+        The checkpointed function produces the same output as the input
+        ``fun`` but recomputes all intermediate states during the gradient
+        computation (vjp) rather than storing them.
+
+        Use the checkpoint transformation to reduce memory consumption at the
+        cost of increased computation.
+
+        Args:
+            fun (function): A function which takes a variable number of
+              :class:`array` or trees of :class:`array` and returns
+              a variable number of :class:`array` or trees of :class:`array`.
+      )pbdoc");
 
   // Register static Python object cleanup before the interpreter exits
   auto atexit = py::module_::import("atexit");
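A minimal usage sketch of the function-level transform documented in the
transforms.cpp hunk above; the toy function and inputs are illustrative
and not part of this patch:

    import mlx.core as mx

    def fn(x):
        # The exp/multiply intermediates are recomputed during the vjp
        # instead of being kept alive for the backward pass.
        return (mx.exp(x) * x).sum()

    x = mx.array([0.5, 1.0, 1.5])
    grad_fn = mx.grad(mx.checkpoint(fn))
    # Same gradients as mx.grad(fn)(x), with lower peak memory
    # for large graphs.
    print(grad_fn(x))
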
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 43cd55177..8cb5f5755 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -1487,10 +1487,22 @@ class TestNNUtils(mlx_tests.MLXTestCase):
         lin = nn.Linear(2, 2)
         x = mx.array([0.1, 0.2])
 
+        lin.my_attr = "hello"
+
         expected_y = lin(x)
-        y = nn.utils.checkpoint(lin)(x)
+        clin = nn.utils.checkpoint(lin)
+        y = clin(x)
         self.assertTrue(mx.allclose(expected_y, y))
 
+        # Check get/set attribute
+        self.assertEqual(clin.my_attr, "hello")
+
+        clin.my_attr = "bye"
+        self.assertEqual(clin.my_attr, "bye")
+
+        self.assertTrue(isinstance(clin, nn.Linear))
+        self.assertEqual(repr(clin), repr(lin))
+
 
 if __name__ == "__main__":
     unittest.main()
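The test above exercises nn.utils.checkpoint as a module wrapper. A hedged
sketch of how that wrapper fits into a training-style step; the loss
function and data here are made up for illustration:

    import mlx.core as mx
    import mlx.nn as nn

    lin = nn.Linear(2, 2)
    # Per the __init__.py hunk this is also re-exported as nn.checkpoint.
    clin = nn.utils.checkpoint(lin)

    def loss_fn(model, x):
        # Scalar loss; the wrapped module's activations are recomputed
        # in the vjp rather than stored.
        return model(x).sum()

    x = mx.array([0.1, 0.2])
    loss, grads = nn.value_and_grad(clin, loss_fn)(clin, x)

As the new assertions verify, the wrapped module still behaves like the
original (isinstance, repr, and attribute get/set all pass through).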