mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Update custom function docs (#1748)
This commit is contained in:
parent
b51d70a83c
commit
eab93985b8
@ -971,6 +971,32 @@ void init_transforms(nb::module_& m) {
|
|||||||
if ay != ax and ax is not None:
|
if ay != ax and ax is not None:
|
||||||
y = y.swapaxes(ay, ax)
|
y = y.swapaxes(ay, ax)
|
||||||
return mx.sin(x) * y, (ax or ay)
|
return mx.sin(x) * y, (ax or ay)
|
||||||
|
|
||||||
|
All ``custom_function`` instances behave as pure functions. Namely, any
|
||||||
|
variables captured will be treated as constants and no gradients will be
|
||||||
|
computed with respect to the captured arrays. For instance:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
def g(x, y):
|
||||||
|
@mx.custom_function
|
||||||
|
def f(x):
|
||||||
|
return x * y
|
||||||
|
|
||||||
|
@f.vjp
|
||||||
|
def f_vjp(x, dx, fx):
|
||||||
|
# Note that we have only x, dx and fx and nothing with respect to y
|
||||||
|
raise ValueError("Abort!")
|
||||||
|
|
||||||
|
return f(x)
|
||||||
|
|
||||||
|
x = mx.array(2.0)
|
||||||
|
y = mx.array(3.0)
|
||||||
|
print(g(x, y)) # prints 6.0
|
||||||
|
print(mx.grad(g)(x, y)) # Raises exception
|
||||||
|
print(mx.grad(g, argnums=1)(x, y)) # prints 0.0
|
||||||
)pbdoc")
|
)pbdoc")
|
||||||
.def(
|
.def(
|
||||||
nb::init<nb::callable>(),
|
nb::init<nb::callable>(),
|
||||||
|
Loading…
Reference in New Issue
Block a user