Update custom function docs (#1748)

This commit is contained in:
Angelos Katharopoulos 2025-01-03 16:35:25 -08:00 committed by GitHub
parent b51d70a83c
commit eab93985b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -971,6 +971,32 @@ void init_transforms(nb::module_& m) {
if ay != ax and ax is not None:
y = y.swapaxes(ay, ax)
return mx.sin(x) * y, (ax or ay)
All ``custom_function`` instances behave as pure functions. Namely, any
variables captured will be treated as constants and no gradients will be
computed with respect to the captured arrays. For instance:
.. code-block:: python
import mlx.core as mx
def g(x, y):
@mx.custom_function
def f(x):
return x * y
@f.vjp
def f_vjp(x, dx, fx):
# Note that we have only x, dx and fx and nothing with respect to y
raise ValueError("Abort!")
return f(x)
x = mx.array(2.0)
y = mx.array(3.0)
print(g(x, y)) # prints 6.0
print(mx.grad(g)(x, y)) # Raises exception
print(mx.grad(g, argnums=1)(x, y)) # prints 0.0
)pbdoc")
.def(
nb::init<nb::callable>(),