docs for checkpoint + a few more tests

This commit is contained in:
Awni Hannun 2024-03-05 15:34:46 -08:00
parent 1368bce280
commit a5827d0384
6 changed files with 38 additions and 8 deletions

View File

@ -173,6 +173,7 @@ In detail:
:toctree: _autosummary
value_and_grad
checkpoint
.. toctree::

View File

@ -17,3 +17,4 @@ Transforms
jvp
vjp
vmap
checkpoint

View File

@ -2,4 +2,4 @@
from mlx.nn import init, losses
from mlx.nn.layers import *
from mlx.nn.utils import value_and_grad
from mlx.nn.utils import checkpoint, value_and_grad

View File

@ -15,7 +15,7 @@ def value_and_grad(model: Module, fn: Callable):
Args:
model (mlx.nn.Module): The model whose trainable parameters to compute
gradients for
gradients for
fn (Callable): The scalar function to compute gradients for
Returns:
@ -38,15 +38,14 @@ def value_and_grad(model: Module, fn: Callable):
def checkpoint(module: Module):
"""Transform the passed module to one that performs gradient
checkpointing.
"""Transform the passed module to one that performs gradient checkpointing.
The checkpointing is with respect to the module's trainable parameters and
inputs of the module's ``__call__`` function.
Args:
module (mlx.nn.Module): The module for whose parameters we will be
performing gradient checkpointing.
module (mlx.nn.Module): The module for which we will perform gradient
checkpointing.
Returns:
The module that saves the inputs and outputs during the forward pass

View File

@ -888,7 +888,24 @@ void init_transforms(py::module_& m) {
m.def(
"checkpoint",
[](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
"fun"_a);
"fun"_a,
R"pbdoc(
checkpoint(fun: function) -> function
Returns a gradient checkpointed function.
The checkpointed function produces the same output as the input
``fun`` but recomputes all intermediate states during the gradient
computation (vjp) rather than storing them.
Use the checkpoint transformation to reduce memory consumption at the
cost of increased computation.
Args:
fun (function): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`.
)pbdoc");
// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");

View File

@ -1487,10 +1487,22 @@ class TestNNUtils(mlx_tests.MLXTestCase):
lin = nn.Linear(2, 2)
x = mx.array([0.1, 0.2])
lin.my_attr = "hello"
expected_y = lin(x)
y = nn.utils.checkpoint(lin)(x)
clin = nn.utils.checkpoint(lin)
y = clin(x)
self.assertTrue(mx.allclose(expected_y, y))
# Check get/set attribute
self.assertEqual(clin.my_attr, "hello")
clin.my_attr = "bye"
self.assertEqual(clin.my_attr, "bye")
self.assertTrue(isinstance(clin, nn.Linear))
self.assertEqual(repr(clin), repr(lin))
if __name__ == "__main__":
unittest.main()