mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Update custom function docs (#1748)
This commit is contained in:
parent
b51d70a83c
commit
eab93985b8
@ -971,6 +971,32 @@ void init_transforms(nb::module_& m) {
            if ay != ax and ax is not None:
                y = y.swapaxes(ay, ax)
            return mx.sin(x) * y, (ax or ay)

        All ``custom_function`` instances behave as pure functions. Namely, any
        variables captured will be treated as constants and no gradients will be
        computed with respect to the captured arrays. For instance:

        .. code-block:: python

          import mlx.core as mx

          def g(x, y):
            @mx.custom_function
            def f(x):
              return x * y

            @f.vjp
            def f_vjp(x, dx, fx):
              # Note that we have only x, dx and fx and nothing with respect to y
              raise ValueError("Abort!")

            return f(x)

          x = mx.array(2.0)
          y = mx.array(3.0)
          print(g(x, y))                      # prints 6.0
          print(mx.grad(g)(x, y))             # Raises exception
          print(mx.grad(g, argnums=1)(x, y))  # prints 0.0
      )pbdoc")
      .def(
          nb::init<nb::callable>(),
Loading…
Reference in New Issue
Block a user