Added clarification to apply_fn parameter of apply_to_modules (#2831)

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Chaoran Yu
2025-11-26 15:40:56 -08:00
committed by GitHub
parent dd79d3c465
commit b054838780
2 changed files with 28 additions and 5 deletions

View File

@@ -1238,8 +1238,18 @@ void init_transforms(nb::module_& m) {
same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).
Returns:
list(array): A list of the Jacobian-vector products which
is the same in number, shape, and type of the inputs to ``fun``.
tuple(list(array), list(array)): A tuple with the outputs of
``fun`` in the first position and the Jacobian-vector products
in the second position.
Example:
.. code-block:: python
import mlx.core as mx
outs, jvps = mx.jvp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
)pbdoc");
m.def(
"vjp",
@@ -1277,8 +1287,18 @@ void init_transforms(nb::module_& m) {
same in number, shape, and type as the outputs of ``fun``.
Returns:
list(array): A list of the vector-Jacobian products which
is the same in number, shape, and type of the outputs of ``fun``.
tuple(list(array), list(array)): A tuple with the outputs of
``fun`` in the first position and the vector-Jacobian products
in the second position.
Example:
.. code-block:: python
import mlx.core as mx
outs, vjps = mx.vjp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
)pbdoc");
m.def(
"value_and_grad",