mirror of https://github.com/ml-explore/mlx.git

docs for checkpoint + a few more tests

commit a5827d0384 (parent 1368bce280)
@@ -173,6 +173,7 @@ In detail:
 
    :toctree: _autosummary
 
    value_and_grad
+   checkpoint
 
 .. toctree::
@@ -17,3 +17,4 @@ Transforms
    jvp
    vjp
    vmap
+   checkpoint
@@ -2,4 +2,4 @@
 from mlx.nn import init, losses
 from mlx.nn.layers import *
-from mlx.nn.utils import value_and_grad
+from mlx.nn.utils import checkpoint, value_and_grad
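With this change, ``checkpoint`` is re-exported from the ``mlx.nn`` namespace alongside ``value_and_grad``. A minimal sketch of the resulting import surface (assuming a build that includes this commit):

    import mlx.nn as nn

    # Both names are defined in mlx.nn.utils and re-exported by the
    # mlx.nn package, so either import path works.
    from mlx.nn import checkpoint, value_and_grad
    from mlx.nn.utils import checkpoint, value_and_grad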
@@ -15,7 +15,7 @@ def value_and_grad(model: Module, fn: Callable):
 
     Args:
         model (mlx.nn.Module): The model whose trainable parameters to compute
-          gradients for
+            gradients for
         fn (Callable): The scalar function to compute gradients for
 
     Returns:
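For context, this docstring describes the pattern where ``nn.value_and_grad`` binds a scalar loss to the model's trainable parameters. A hedged sketch; the model, data, and loss below are illustrative placeholders, not part of this commit:

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(2, 1)        # placeholder model
    X = mx.random.normal((8, 2))   # placeholder inputs
    y = mx.random.normal((8, 1))   # placeholder targets

    def loss_fn(model, X, y):
        # The scalar function to differentiate
        return mx.mean((model(X) - y) ** 2)

    # Returns fn's value and the gradients wrt model.trainable_parameters()
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, X, y)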
@@ -38,15 +38,14 @@ def value_and_grad(model: Module, fn: Callable):
 
 
 def checkpoint(module: Module):
-    """Transform the passed module to one that performs gradient
-    checkpointing.
+    """Transform the passed module to one that performs gradient checkpointing.
 
     The checkpointing is with respect to the module's trainable parameters and
     inputs of the module's ``__call__`` function.
 
     Args:
-        module (mlx.nn.Module): The module for whose parameters we will be
-            performing gradient checkpointing.
+        module (mlx.nn.Module): The module for which we will perform gradient
+            checkpointing.
 
     Returns:
         The module that saves the inputs and outputs during the forward pass
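A short sketch of this module-level transform in use; the layer and input below are invented for illustration:

    import mlx.core as mx
    import mlx.nn as nn

    layer = nn.Linear(4, 4)
    clayer = nn.utils.checkpoint(layer)  # checkpointed view of the same layer

    x = mx.random.normal((4,))

    # Forward results agree; during the backward pass the checkpointed
    # module recomputes its intermediates instead of keeping them alive.
    assert mx.allclose(layer(x), clayer(x))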
@@ -888,7 +888,24 @@ void init_transforms(py::module_& m) {
   m.def(
       "checkpoint",
       [](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
-      "fun"_a);
+      "fun"_a,
+      R"pbdoc(
+        checkpoint(fun: function) -> function
+
+        Returns a gradient checkpointed function.
+
+        The checkpointed function produces the same output as the input
+        ``fun`` but recomputes all intermediate states during the gradient
+        computation (vjp) rather than storing them.
+
+        Use the checkpoint transformation to reduce memory consumption at the
+        cost of increased computation.
+
+        Args:
+            fun (function): A function which takes a variable number of
+              :class:`array` or trees of :class:`array` and returns
+              a variable number of :class:`array` or trees of :class:`array`.
+      )pbdoc");
 
   // Register static Python object cleanup before the interpreter exits
   auto atexit = py::module_::import("atexit");
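A hedged usage sketch of the function-level transform documented above; ``fn`` and the shapes are illustrative:

    import mlx.core as mx

    def fn(x):
        # Intermediates (y, z) that a plain vjp would keep alive
        y = mx.exp(x)
        z = mx.sin(y)
        return mx.sum(z * z)

    cfn = mx.checkpoint(fn)
    x = mx.random.normal((1024,))

    # Same gradient either way; the checkpointed version trades memory
    # for recomputing y and z during the vjp.
    assert mx.allclose(mx.grad(fn)(x), mx.grad(cfn)(x))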
@@ -1487,10 +1487,22 @@ class TestNNUtils(mlx_tests.MLXTestCase):
         lin = nn.Linear(2, 2)
         x = mx.array([0.1, 0.2])
 
+        lin.my_attr = "hello"
+
         expected_y = lin(x)
-        y = nn.utils.checkpoint(lin)(x)
+        clin = nn.utils.checkpoint(lin)
+        y = clin(x)
         self.assertTrue(mx.allclose(expected_y, y))
 
+        # Check get/set attribute
+        self.assertEqual(clin.my_attr, "hello")
+
+        clin.my_attr = "bye"
+        self.assertEqual(clin.my_attr, "bye")
+
+        self.assertTrue(isinstance(clin, nn.Linear))
+        self.assertEqual(repr(clin), repr(lin))
+
 
 if __name__ == "__main__":
     unittest.main()
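As the new assertions suggest, the wrapper returned by ``nn.utils.checkpoint`` behaves as a transparent proxy for the original module: attribute reads and writes pass through, and the wrapped object still reports the original class. A small sketch of that behavior (``tag`` is an invented attribute):

    import mlx.nn as nn

    clin = nn.utils.checkpoint(nn.Linear(2, 2))

    print(clin.weight.shape)   # reads reach the underlying Linear: (2, 2)
    clin.tag = "hello"         # writes pass through as well
    assert isinstance(clin, nn.Linear)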