docs update

This commit is contained in:
Awni Hannun 2024-02-08 12:44:23 -08:00 committed by CircleCI Docs
parent 17470bf630
commit 9e69a72b8c
437 changed files with 11568 additions and 13689 deletions

View File

@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: df22fdae6eaa6299681f0aab7c5d6029
config: b49cb089891263e82aedf5bc4cacbe8a
tags: 645f666f9bcd5a90fca523b33c5a78b7

View File

@ -677,9 +677,9 @@ Let's look at the overall directory structure first.
Binding to Python
^^^^^^^^^^^^^^^^^^
We use PyBind11_ to build a Python API for the C++ library. Since bindings
for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
are already provided, adding our :meth:`axpby` becomes very simple!
We use PyBind11_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple!
.. code-block:: C++
@ -927,18 +927,18 @@ Results:
We see some modest improvements right away!
This operation is now good to be used to build other operations,
in :class:`mlx.nn.Module` calls, and also as a part of graph
transformations like :meth:`grad`!
This operation can now be used to build other operations, in
:class:`mlx.nn.Module` calls, and as part of graph transformations like
:meth:`grad`!
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx-examples <code>`_.
The full example code is available in `mlx <code>`_.
.. _code: TODO_LINK/extensions
.. _code: https://github.com/ml-explore/mlx/tree/main/examples/extensions/
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc

View File

@ -41,6 +41,7 @@ are the CPU and GPU.
usage/indexing
usage/saving_and_loading
usage/function_transforms
usage/compile
usage/numpy
usage/using_streams

View File

@ -0,0 +1,6 @@
mlx.core.compile
================
.. currentmodule:: mlx.core
.. autofunction:: compile

View File

@ -0,0 +1,6 @@
mlx.core.disable\_compile
=========================
.. currentmodule:: mlx.core
.. autofunction:: disable_compile

View File

@ -0,0 +1,6 @@
mlx.core.enable\_compile
========================
.. currentmodule:: mlx.core
.. autofunction:: enable_compile

View File

@ -1,6 +0,0 @@
mlx.core.simplify
=================
.. currentmodule:: mlx.core
.. autofunction:: simplify

View File

@ -14,5 +14,6 @@
~AdaDelta.__init__
~AdaDelta.apply_single
~AdaDelta.init_single

View File

@ -14,5 +14,6 @@
~Adafactor.__init__
~Adafactor.apply_single
~Adafactor.init_single

View File

@ -14,5 +14,6 @@
~Adagrad.__init__
~Adagrad.apply_single
~Adagrad.init_single

View File

@ -14,5 +14,6 @@
~Adam.__init__
~Adam.apply_single
~Adam.init_single

View File

@ -14,5 +14,6 @@
~Adamax.__init__
~Adamax.apply_single
~Adamax.init_single

View File

@ -14,5 +14,6 @@
~Lion.__init__
~Lion.apply_single
~Lion.init_single

View File

@ -0,0 +1,6 @@
mlx.optimizers.Optimizer.apply\_gradients
=========================================
.. currentmodule:: mlx.optimizers
.. automethod:: Optimizer.apply_gradients

View File

@ -0,0 +1,6 @@
mlx.optimizers.Optimizer.init
=============================
.. currentmodule:: mlx.optimizers
.. automethod:: Optimizer.init

View File

@ -1,20 +0,0 @@
mlx.optimizers.Optimizer
========================
.. currentmodule:: mlx.optimizers
.. autoclass:: Optimizer
.. rubric:: Methods
.. autosummary::
~Optimizer.__init__
~Optimizer.apply_gradients
~Optimizer.apply_single
~Optimizer.update

View File

@ -0,0 +1,6 @@
mlx.optimizers.Optimizer.state
==============================
.. currentmodule:: mlx.optimizers
.. autoproperty:: Optimizer.state

View File

@ -0,0 +1,6 @@
mlx.optimizers.Optimizer.update
===============================
.. currentmodule:: mlx.optimizers
.. automethod:: Optimizer.update

View File

@ -1,17 +0,0 @@
mlx.optimizers.OptimizerState
=============================
.. currentmodule:: mlx.optimizers
.. autoclass:: OptimizerState
.. rubric:: Methods
.. autosummary::
~OptimizerState.get

View File

@ -14,5 +14,6 @@
~RMSprop.__init__
~RMSprop.apply_single
~RMSprop.init_single

View File

@ -14,5 +14,6 @@
~SGD.__init__
~SGD.apply_single
~SGD.init_single

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: ALiBi

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: BatchNorm

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: Conv1d

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: Conv2d

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: Dropout

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: Dropout2d

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: Dropout3d

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: Embedding

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: GELU

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: GroupNorm

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: InstanceNorm

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: LayerNorm

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: Linear

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: Mish

View File

@ -0,0 +1,6 @@
mlx.nn.Module.state
===================
.. currentmodule:: mlx.nn
.. autoproperty:: Module.state

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: MultiHeadAttention

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: PReLU

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: QuantizedLinear

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: RMSNorm

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: ReLU

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: RoPE

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: SELU

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: Sequential

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: SiLU

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: SinusoidalPositionalEncoding

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: Softshrink

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: Step

View File

@ -4,5 +4,3 @@
.. currentmodule:: mlx.nn
.. autoclass:: Transformer

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
.. autoclass:: gelu
.. autofunction:: gelu

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
.. autoclass:: gelu_approx
.. autofunction:: gelu_approx

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
.. autoclass:: gelu_fast_approx
.. autofunction:: gelu_fast_approx

View File

@ -1,6 +0,0 @@
mlx.nn.init.constant
====================
.. currentmodule:: mlx.nn.init
.. autofunction:: constant

View File

@ -1,6 +0,0 @@
mlx.nn.init.glorot\_normal
==========================
.. currentmodule:: mlx.nn.init
.. autofunction:: glorot_normal

View File

@ -1,6 +0,0 @@
mlx.nn.init.glorot\_uniform
===========================
.. currentmodule:: mlx.nn.init
.. autofunction:: glorot_uniform

View File

@ -1,6 +0,0 @@
mlx.nn.init.he\_normal
======================
.. currentmodule:: mlx.nn.init
.. autofunction:: he_normal

View File

@ -1,6 +0,0 @@
mlx.nn.init.he\_uniform
=======================
.. currentmodule:: mlx.nn.init
.. autofunction:: he_uniform

View File

@ -1,6 +0,0 @@
mlx.nn.init.identity
====================
.. currentmodule:: mlx.nn.init
.. autofunction:: identity

View File

@ -1,6 +0,0 @@
mlx.nn.init.normal
==================
.. currentmodule:: mlx.nn.init
.. autofunction:: normal

View File

@ -1,6 +0,0 @@
mlx.nn.init.uniform
===================
.. currentmodule:: mlx.nn.init
.. autofunction:: uniform

View File

@ -1,6 +0,0 @@
mlx.nn.initializers.constant
============================
.. currentmodule:: mlx.nn.initializers
.. autofunction:: constant

View File

@ -1,6 +0,0 @@
mlx.nn.initializers.glorot\_normal
==================================
.. currentmodule:: mlx.nn.initializers
.. autofunction:: glorot_normal

View File

@ -1,6 +0,0 @@
mlx.nn.initializers.glorot\_uniform
===================================
.. currentmodule:: mlx.nn.initializers
.. autofunction:: glorot_uniform

View File

@ -1,6 +0,0 @@
mlx.nn.initializers.he\_normal
==============================
.. currentmodule:: mlx.nn.initializers
.. autofunction:: he_normal

View File

@ -1,6 +0,0 @@
mlx.nn.initializers.he\_uniform
===============================
.. currentmodule:: mlx.nn.initializers
.. autofunction:: he_uniform

View File

@ -1,6 +0,0 @@
mlx.nn.initializers.identity
============================
.. currentmodule:: mlx.nn.initializers
.. autofunction:: identity

View File

@ -1,6 +0,0 @@
mlx.nn.initializers.normal
==========================
.. currentmodule:: mlx.nn.initializers
.. autofunction:: normal

View File

@ -1,6 +0,0 @@
mlx.nn.initializers.uniform
===========================
.. currentmodule:: mlx.nn.initializers
.. autofunction:: uniform

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
.. autoclass:: binary_cross_entropy
.. autofunction:: binary_cross_entropy

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
.. autoclass:: cosine_similarity_loss
.. autofunction:: cosine_similarity_loss

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
.. autoclass:: cross_entropy
.. autofunction:: cross_entropy

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
.. autoclass:: gaussian_nll_loss
.. autofunction:: gaussian_nll_loss

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
.. autoclass:: hinge_loss
.. autofunction:: hinge_loss

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
.. autoclass:: huber_loss
.. autofunction:: huber_loss

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
.. autoclass:: kl_div_loss
.. autofunction:: kl_div_loss

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
.. autoclass:: l1_loss
.. autofunction:: l1_loss

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
.. autoclass:: log_cosh_loss
.. autofunction:: log_cosh_loss

View File

@ -0,0 +1,6 @@
mlx.nn.losses.margin\_ranking\_loss
===================================
.. currentmodule:: mlx.nn.losses
.. autofunction:: margin_ranking_loss

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
.. autoclass:: mse_loss
.. autofunction:: mse_loss

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
.. autoclass:: nll_loss
.. autofunction:: nll_loss

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
.. autoclass:: smooth_l1_loss
.. autofunction:: smooth_l1_loss

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn.losses
.. autoclass:: triplet_loss
.. autofunction:: triplet_loss

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
.. autoclass:: mish
.. autofunction:: mish

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
.. autoclass:: prelu
.. autofunction:: prelu

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
.. autoclass:: relu
.. autofunction:: relu

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
.. autoclass:: selu
.. autofunction:: selu

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
.. autoclass:: silu
.. autofunction:: silu

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
.. autoclass:: softshrink
.. autofunction:: softshrink

View File

@ -3,6 +3,4 @@
.. currentmodule:: mlx.nn
.. autoclass:: step
.. autofunction:: step

View File

@ -1,18 +0,0 @@
.. _initializers:
.. currentmodule:: mlx.nn.initializers
Initializers
--------------
.. autosummary::
:toctree: _autosummary_functions
constant
normal
uniform
identity
glorot_normal
glorot_uniform
he_normal
he_uniform

View File

@ -18,6 +18,7 @@ Loss Functions
kl_div_loss
l1_loss
log_cosh_loss
margin_ranking_loss
mse_loss
nll_loss
smooth_l1_loss

View File

@ -11,6 +11,7 @@ Module
:toctree: _autosummary
Module.training
Module.state
.. rubric:: Methods

View File

@ -0,0 +1,23 @@
Optimizer
=========
.. currentmodule:: mlx.optimizers
.. autoclass:: Optimizer
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Optimizer.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Optimizer.apply_gradients
Optimizer.init
Optimizer.update

View File

@ -29,14 +29,16 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
.. toctree::
optimizer
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
OptimizerState
Optimizer
SGD
RMSprop
Adagrad

View File

@ -9,6 +9,9 @@ Transforms
:toctree: _autosummary
eval
compile
disable_compile
enable_compile
grad
value_and_grad
jvp

View File

@ -0,0 +1,430 @@
.. _compile:
Compilation
===========
.. currentmodule:: mlx.core
MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.
Getting started with :func:`compile` is simple, but there are some edge cases
worth being aware of for more complex graphs and advanced usage.
Basics of Compile
-----------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
print(fun(x, y))
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
compiled_fun = mx.compile(fun)
# Compiled here
compiled_fun(x, y)
# Not compiled again
compiled_fun(x, y)
# Not compiled again
mx.compile(fun)(x, y)
There are some important cases to be aware of that can cause a function to
be recompiled:
* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function
In certain cases only some of the compilation stack will be rerun (for
example when changing the shapes) and in other cases the full compilation
stack will be rerun (for example when changing the types). In general you
should avoid compiling functions too frequently.
Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:
.. code-block:: python
a = mx.array(1.0)
# Don't do this, compiles lambda at each iteration
for _ in range(5):
mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
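A simple fix is to create the compiled function once, outside the loop, and
reuse it:

.. code-block:: python

   a = mx.array(1.0)

   # Compile once, then reuse the same function object in the loop
   cfun = mx.compile(lambda x: mx.exp(mx.abs(x)))
   for _ in range(5):
       cfun(a)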
Example Speedup
---------------
:func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:
.. code-block:: python
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays, it will be memory bandwidth bound. However, all of
the operations in ``gelu`` can be fused into a single kernel with
:func:`compile`, which can speed up both cases considerably.
Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:
.. code-block:: python
import time
def timeit(fun, x):
# warm up
for _ in range(10):
mx.eval(fun(x))
tic = time.perf_counter()
for _ in range(100):
mx.eval(fun(x))
toc = time.perf_counter()
tpi = 1e3 * (toc - tic) / 100
print(f"Time per iteration {tpi:.3f} (ms)")
Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
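To see this on your own machine, one option is to switch the default device to
the CPU and rerun the benchmark above (a sketch, assuming the ``timeit``
helper from earlier; absolute numbers will vary):

.. code-block:: python

   # Expect a smaller relative speedup than on the GPU
   mx.set_default_device(mx.cpu)
   timeit(nn.gelu, x)
   timeit(mx.compile(nn.gelu), x)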
Debugging
---------
When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Crash
return mx.exp(z)
fun(mx.array(5.0))
For debugging, it is often helpful to inspect intermediate arrays. One way to
do that is to globally disable compilation using the :func:`disable_compile`
function or the ``MLX_DISABLE_COMPILE`` environment variable. For example, the
following is okay even though ``fun`` is compiled:
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Okay
return mx.exp(z)
mx.disable_compile()
fun(mx.array(5.0))
Pure Functions
--------------
Compiled functions are intended to be *pure*; that is, they should not have
side effects. For example:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z)
fun(mx.array(1.0), mx.array(2.0))
# Crash!
print(state)
After the first call to ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.
You have two options to deal with this. The first option is to simply return
``state`` as an output:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
_, state = fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:
.. code-block:: python
from functools import partial
state = []
# Tell compile to capture state as an output
@partial(mx.compile, outputs=state)
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.
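For instance, here is a minimal sketch (the ``params`` container and
``update`` function are hypothetical) of capturing a dictionary of arrays
that is updated in place:

.. code-block:: python

   from functools import partial

   # A toy container of parameters updated by the compiled function
   params = {"w": mx.array(1.0)}

   @partial(mx.compile, inputs=params, outputs=params)
   def update(x):
       params["w"] = params["w"] - 0.1 * x
       return params["w"]

   update(mx.array(1.0))
   # Prints array(0.9, dtype=float32)
   print(params["w"])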
Compiled functions will also treat any inputs not in the parameter list as
constants. For example:
.. code-block:: python
state = [mx.array(1.0)]
@mx.compile
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Still prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
To have changes in ``state`` reflected in the outputs of ``fun``, you again
have two options. The first option is to simply pass ``state`` as an input to
the function. In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:
.. code-block:: python
from functools import partial
state = [mx.array(1.0)]
# Tell compile to capture state as an input
@partial(mx.compile, inputs=state)
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Prints array(6, dtype=float32)
print(fun(mx.array(1.0)))
Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
of a common setup: training a model with :obj:`mlx.nn.Module` using an
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Perform 10 steps of gradient descent
for it in range(10):
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
# Perform 10 steps of gradient descent
for it in range(10):
loss = step(x, y)
# Evaluate the model and optimizer state
mx.eval(state)
print(loss)
.. note::
If you are using a module which performs random sampling such as
:class:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
``state`` captured by :func:`compile`, i.e. ``state = [model.state,
optimizer.state, mx.random.state]``.
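For example, a sketch of the state capture with the random state included:

.. code-block:: python

   # Include the global PRNG state so that random operations (such as
   # dropout masks) advance correctly across compiled calls
   state = [model.state, optimizer.state, mx.random.state]

   @partial(mx.compile, inputs=state, outputs=state)
   def step(x, y):
       loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
       optimizer.update(model, grads)
       return loss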
.. note::
For more examples of compiling full training graphs, check out the `MLX
Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.
Transformations with Compile
----------------------------
In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
<function_transforms>`.
Compiling transformed functions works just as expected:
.. code-block:: python
grad_fn = mx.grad(mx.exp)
compiled_grad_fn = mx.compile(grad_fn)
# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))
# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))
.. note::
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
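For example, a short sketch of the difference:

.. code-block:: python

   # The gradient graph of a compiled function is not itself compiled
   grad_fn = mx.grad(mx.compile(mx.exp))

   # Pass the transformed function through compile to compile it as well
   compiled_grad_fn = mx.compile(mx.grad(mx.compile(mx.exp)))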
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outermost function to give :func:`compile`
the most opportunity to optimize the computation graph:
.. code-block:: python
@mx.compile
def inner(x):
return mx.exp(-mx.abs(x))
def outer(x):
    return inner(inner(x))

# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)

View File

@ -5,9 +5,12 @@ Function Transforms
.. currentmodule:: mlx.core
MLX uses composable function transformations for automatic differentiation and
vectorization. The key idea behind composable function transformations is that
every transformation returns a function which can be further transformed.
MLX uses composable function transformations for automatic differentiation,
vectorization, and compute graph optimizations. To see the complete list of
function transformations, check out the :ref:`API documentation <transforms>`.
The key idea behind composable function transformations is that every
transformation returns a function which can be further transformed.
Here is a simple example:
@ -36,10 +39,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep
getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
depth. To see the complete list of function transformations check-out the
:ref:`API documentation <transforms>`. See the following sections for more
information on :ref:`automatic differentiaion <auto diff>` and
:ref:`automatic vectorization <vmap>`.
depth. See the following sections for more information on :ref:`automatic
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
Automatic Differentiation
-------------------------

View File

@ -1,6 +1,6 @@
var DOCUMENTATION_OPTIONS = {
URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'),
VERSION: '0.1.0',
VERSION: '0.2.0',
LANGUAGE: 'en',
COLLAPSE_INDEX: false,
BUILDER: 'html',

View File

@ -9,7 +9,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<title>Operations &#8212; MLX 0.1.0 documentation</title>
<title>Operations &#8212; MLX 0.2.0 documentation</title>
@ -134,8 +134,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.1.0 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.1.0 documentation - Home"/>`);</script>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.2.0 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.2.0 documentation - Home"/>`);</script>
</a></div>
@ -153,6 +153,7 @@
<li class="toctree-l1"><a class="reference internal" href="../usage/indexing.html">Indexing Arrays</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/saving_and_loading.html">Saving and Loading Arrays</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/function_transforms.html">Function Transforms</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/compile.html">Compilation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
</ul>
@ -348,6 +349,9 @@
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/transforms.html">Transforms</a><input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-5"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eval.html">mlx.core.eval</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.compile.html">mlx.core.compile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.disable_compile.html">mlx.core.disable_compile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.enable_compile.html">mlx.core.enable_compile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.grad.html">mlx.core.grad</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.value_and_grad.html">mlx.core.value_and_grad</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.jvp.html">mlx.core.jvp</a></li>
@ -379,6 +383,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.value_and_grad.html">mlx.nn.value_and_grad</a></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/module.html">Module</a><input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-9"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.training.html">mlx.nn.Module.training</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.state.html">mlx.nn.Module.state</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply.html">mlx.nn.Module.apply</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply_to_modules.html">mlx.nn.Module.apply_to_modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.children.html">mlx.nn.Module.children</a></li>
@ -451,6 +456,7 @@
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html">mlx.nn.losses.kl_div_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.html">mlx.nn.losses.l1_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html">mlx.nn.losses.log_cosh_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.html">mlx.nn.losses.margin_ranking_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.html">mlx.nn.losses.mse_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html">mlx.nn.losses.nll_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html">mlx.nn.losses.smooth_l1_loss</a></li>
@ -471,8 +477,13 @@
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/optimizers.html">Optimizers</a><input class="toctree-checkbox" id="toctree-checkbox-14" name="toctree-checkbox-14" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-14"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.OptimizerState.html">mlx.optimizers.OptimizerState</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Optimizer.html">mlx.optimizers.Optimizer</a></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/optimizer.html">Optimizer</a><input class="toctree-checkbox" id="toctree-checkbox-15" name="toctree-checkbox-15" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-15"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Optimizer.state.html">mlx.optimizers.Optimizer.state</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Optimizer.apply_gradients.html">mlx.optimizers.Optimizer.apply_gradients</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Optimizer.init.html">mlx.optimizers.Optimizer.init</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Optimizer.update.html">mlx.optimizers.Optimizer.update</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.SGD.html">mlx.optimizers.SGD</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.RMSprop.html">mlx.optimizers.RMSprop</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Adagrad.html">mlx.optimizers.Adagrad</a></li>
@ -484,7 +495,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Lion.html">mlx.optimizers.Lion</a></li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a><input class="toctree-checkbox" id="toctree-checkbox-15" name="toctree-checkbox-15" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-15"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a><input class="toctree-checkbox" id="toctree-checkbox-16" name="toctree-checkbox-16" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-16"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_flatten.html">mlx.utils.tree_flatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_unflatten.html">mlx.utils.tree_unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_map.html">mlx.utils.tree_map</a></li>

View File

@ -9,7 +9,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<title>Developer Documentation &#8212; MLX 0.1.0 documentation</title>
<title>Developer Documentation &#8212; MLX 0.2.0 documentation</title>
@ -133,8 +133,8 @@
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.1.0 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.1.0 documentation - Home"/>`);</script>
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.2.0 documentation - Home"/>
<script>document.write(`<img src="../_static/mlx_logo.png" class="logo__image only-dark" alt="MLX 0.2.0 documentation - Home"/>`);</script>
</a></div>
@ -152,6 +152,7 @@
<li class="toctree-l1"><a class="reference internal" href="../usage/indexing.html">Indexing Arrays</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/saving_and_loading.html">Saving and Loading Arrays</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/function_transforms.html">Function Transforms</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/compile.html">Compilation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/numpy.html">Conversion to NumPy and Other Frameworks</a></li>
<li class="toctree-l1"><a class="reference internal" href="../usage/using_streams.html">Using Streams</a></li>
</ul>
@ -347,6 +348,9 @@
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/transforms.html">Transforms</a><input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-5"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.eval.html">mlx.core.eval</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.compile.html">mlx.core.compile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.disable_compile.html">mlx.core.disable_compile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.enable_compile.html">mlx.core.enable_compile</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.grad.html">mlx.core.grad</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.value_and_grad.html">mlx.core.value_and_grad</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.core.jvp.html">mlx.core.jvp</a></li>
@ -378,6 +382,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.nn.value_and_grad.html">mlx.nn.value_and_grad</a></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/nn/module.html">Module</a><input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-9"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.training.html">mlx.nn.Module.training</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.state.html">mlx.nn.Module.state</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply.html">mlx.nn.Module.apply</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.apply_to_modules.html">mlx.nn.Module.apply_to_modules</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary/mlx.nn.Module.children.html">mlx.nn.Module.children</a></li>
@ -450,6 +455,7 @@
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.kl_div_loss.html">mlx.nn.losses.kl_div_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.l1_loss.html">mlx.nn.losses.l1_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.log_cosh_loss.html">mlx.nn.losses.log_cosh_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.margin_ranking_loss.html">mlx.nn.losses.margin_ranking_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.mse_loss.html">mlx.nn.losses.mse_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.nll_loss.html">mlx.nn.losses.nll_loss</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/nn/_autosummary_functions/mlx.nn.losses.smooth_l1_loss.html">mlx.nn.losses.smooth_l1_loss</a></li>
@ -470,8 +476,13 @@
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/optimizers.html">Optimizers</a><input class="toctree-checkbox" id="toctree-checkbox-14" name="toctree-checkbox-14" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-14"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.OptimizerState.html">mlx.optimizers.OptimizerState</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Optimizer.html">mlx.optimizers.Optimizer</a></li>
<li class="toctree-l2 has-children"><a class="reference internal" href="../python/optimizer.html">Optimizer</a><input class="toctree-checkbox" id="toctree-checkbox-15" name="toctree-checkbox-15" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-15"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Optimizer.state.html">mlx.optimizers.Optimizer.state</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Optimizer.apply_gradients.html">mlx.optimizers.Optimizer.apply_gradients</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Optimizer.init.html">mlx.optimizers.Optimizer.init</a></li>
<li class="toctree-l3"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Optimizer.update.html">mlx.optimizers.Optimizer.update</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.SGD.html">mlx.optimizers.SGD</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.RMSprop.html">mlx.optimizers.RMSprop</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Adagrad.html">mlx.optimizers.Adagrad</a></li>
@ -483,7 +494,7 @@
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.optimizers.Lion.html">mlx.optimizers.Lion</a></li>
</ul>
</li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a><input class="toctree-checkbox" id="toctree-checkbox-15" name="toctree-checkbox-15" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-15"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l1 has-children"><a class="reference internal" href="../python/tree_utils.html">Tree Utils</a><input class="toctree-checkbox" id="toctree-checkbox-16" name="toctree-checkbox-16" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-16"><i class="fa-solid fa-chevron-down"></i></label><ul>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_flatten.html">mlx.utils.tree_flatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_unflatten.html">mlx.utils.tree_unflatten</a></li>
<li class="toctree-l2"><a class="reference internal" href="../python/_autosummary/mlx.utils.tree_map.html">mlx.utils.tree_map</a></li>
@ -1326,9 +1337,9 @@ the python package</p></li>
</ul>
<section id="binding-to-python">
<h3>Binding to Python<a class="headerlink" href="#binding-to-python" title="Permalink to this heading">#</a></h3>
<p>We use <a class="reference external" href="https://pybind11.readthedocs.io/en/stable/">PyBind11</a> to build a Python API for the C++ library. Since bindings
for all needed components such as <cite>mlx.core.array</cite>, <cite>mlx.core.stream</cite>, etc.
are already provided, adding our <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> becomes very simple!</p>
<p>We use <a class="reference external" href="https://pybind11.readthedocs.io/en/stable/">PyBind11</a> to build a Python API for the C++ library. Since bindings for
components such as <a class="reference internal" href="../python/_autosummary/mlx.core.array.html#mlx.core.array" title="mlx.core.array"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.array</span></code></a>, <code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.core.stream</span></code>, etc. are
already provided, adding our <code class="xref py py-meth docutils literal notranslate"><span class="pre">axpby()</span></code> is simple!</p>
<div class="highlight-C++ notranslate"><div class="highlight"><pre><span></span><span class="n">PYBIND11_MODULE</span><span class="p">(</span><span class="n">mlx_sample_extensions</span><span class="p">,</span><span class="w"> </span><span class="n">m</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
<span class="w"> </span><span class="n">m</span><span class="p">.</span><span class="n">doc</span><span class="p">()</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s">&quot;Sample C++ and metal extensions for MLX&quot;</span><span class="p">;</span><span class="w"></span>
@ -1552,16 +1563,16 @@ with the naive <code class="xref py py-meth docutils literal notranslate"><span
</pre></div>
</div>
<p>We see some modest improvements right away!</p>
<p>This operation is now good to be used to build other operations,
in <a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> calls, and also as a part of graph
transformations like <code class="xref py py-meth docutils literal notranslate"><span class="pre">grad()</span></code>!</p>
<p>This operation can now be used to build other operations, in
<a class="reference internal" href="../python/nn/module.html#mlx.nn.Module" title="mlx.nn.Module"><code class="xref py py-class docutils literal notranslate"><span class="pre">mlx.nn.Module</span></code></a> calls, and as part of graph transformations like
<code class="xref py py-meth docutils literal notranslate"><span class="pre">grad()</span></code>!</p>
</section>
</section>
<section id="scripts">
<h2>Scripts<a class="headerlink" href="#scripts" title="Permalink to this heading">#</a></h2>
<div class="admonition-download-the-code admonition">
<p class="admonition-title">Download the code</p>
<p>The full example code is available in <a class="reference external" href="code">mlx-examples</a>.</p>
<p>The full example code is available in <a class="reference external" href="code">mlx</a>.</p>
</div>
</section>
</section>

Some files were not shown because too many files have changed in this diff.