From b054838780842fa95738a495a1eceb762a786ca5 Mon Sep 17 00:00:00 2001 From: Chaoran Yu Date: Wed, 26 Nov 2025 15:40:56 -0800 Subject: [PATCH] Added clarification to apply_fn parameter of apply_to_modules (#2831) Co-authored-by: Awni Hannun --- python/mlx/nn/layers/base.py | 5 ++++- python/src/transforms.cpp | 28 ++++++++++++++++++++++++---- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 1ae927279..ce009d9f8 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -407,7 +407,10 @@ class Module(dict): instance). Args: - apply_fn (Callable): The function to apply to the modules. + apply_fn (Callable): The function to apply to the modules which + takes two parameters. The first parameter is the string path of + the module (e.g. ``"model.layers.0.linear"``). The second + parameter is the module object. Returns: The module instance after updating submodules. diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 12aa641f8..0530fa089 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -1238,8 +1238,18 @@ void init_transforms(nb::module_& m) { same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``). Returns: - list(array): A list of the Jacobian-vector products which - is the same in number, shape, and type of the inputs to ``fun``. + tuple(list(array), list(array)): A tuple with the outputs of + ``fun`` in the first position and the Jacobian-vector products + in the second position. + + Example: + + .. code-block:: python + + import mlx.core as mx + + outs, jvps = mx.jvp(mx.sin, (mx.array(1.0),), (mx.array(1.0),)) + )pbdoc"); m.def( "vjp", @@ -1277,8 +1287,18 @@ void init_transforms(nb::module_& m) { same in number, shape, and type as the outputs of ``fun``. Returns: - list(array): A list of the vector-Jacobian products which - is the same in number, shape, and type of the outputs of ``fun``. + tuple(list(array), list(array)): A tuple with the outputs of + ``fun`` in the first position and the vector-Jacobian products + in the second position. + + Example: + + .. code-block:: python + + import mlx.core as mx + + outs, vjps = mx.vjp(mx.sin, (mx.array(1.0),), (mx.array(1.0),)) + )pbdoc"); m.def( "value_and_grad",