From eab93985b8ca2311b8d067cde2a095215c5237e0 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 3 Jan 2025 16:35:25 -0800 Subject: [PATCH] Update custom function docs (#1748) --- python/src/transforms.cpp | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index ef1edbe25..fdf1dc0e5 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -971,6 +971,32 @@ void init_transforms(nb::module_& m) { if ay != ax and ax is not None: y = y.swapaxes(ay, ax) return mx.sin(x) * y, (ax or ay) + + All ``custom_function`` instances behave as pure functions. Namely, any + variables captured will be treated as constants and no gradients will be + computed with respect to the captured arrays. For instance: + + .. code-block:: python + + import mlx.core as mx + + def g(x, y): + @mx.custom_function + def f(x): + return x * y + + @f.vjp + def f_vjp(x, dx, fx): + # Note that we have only x, dx and fx and nothing with respect to y + raise ValueError("Abort!") + + return f(x) + + x = mx.array(2.0) + y = mx.array(3.0) + print(g(x, y)) # prints 6.0 + print(mx.grad(g)(x, y)) # Raises exception + print(mx.grad(g, argnums=1)(x, y)) # prints 0.0 )pbdoc") .def( nb::init(),