Added clarification to apply_fn parameter of apply_to_modules (#2831)

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Chaoran Yu
2025-11-26 15:40:56 -08:00
committed by GitHub
parent dd79d3c465
commit b054838780
2 changed files with 28 additions and 5 deletions

View File

@@ -407,7 +407,10 @@ class Module(dict):
instance). instance).
Args: Args:
apply_fn (Callable): The function to apply to the modules. apply_fn (Callable): The function to apply to the modules which
takes two parameters. The first parameter is the string path of
the module (e.g. ``"model.layers.0.linear"``). The second
parameter is the module object.
Returns: Returns:
The module instance after updating submodules. The module instance after updating submodules.

View File

@@ -1238,8 +1238,18 @@ void init_transforms(nb::module_& m) {
same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``). same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).
Returns: Returns:
list(array): A list of the Jacobian-vector products which tuple(list(array), list(array)): A tuple with the outputs of
is the same in number, shape, and type of the inputs to ``fun``. ``fun`` in the first position and the Jacobian-vector products
in the second position.
Example:
.. code-block:: python
import mlx.core as mx
outs, jvps = mx.jvp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
)pbdoc"); )pbdoc");
m.def( m.def(
"vjp", "vjp",
@@ -1277,8 +1287,18 @@ void init_transforms(nb::module_& m) {
same in number, shape, and type as the outputs of ``fun``. same in number, shape, and type as the outputs of ``fun``.
Returns: Returns:
list(array): A list of the vector-Jacobian products which tuple(list(array), list(array)): A tuple with the outputs of
is the same in number, shape, and type of the outputs of ``fun``. ``fun`` in the first position and the vector-Jacobian products
in the second position.
Example:
.. code-block:: python
import mlx.core as mx
outs, vjps = mx.vjp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
)pbdoc"); )pbdoc");
m.def( m.def(
"value_and_grad", "value_and_grad",