Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-28 21:21:21 +08:00)
Parent: 7dccd42133
Commit: 5c03efaf29
@@ -41,6 +41,7 @@ are the CPU and GPU.
 usage/indexing
 usage/saving_and_loading
 usage/function_transforms
+usage/compile
 usage/numpy
 usage/using_streams
@@ -9,6 +9,9 @@ Transforms
 :toctree: _autosummary

 eval
+compile
+disable_compile
+enable_compile
 grad
 value_and_grad
 jvp
430
docs/src/usage/compile.rst
Normal file
@@ -0,0 +1,430 @@
.. _compile:

Compilation
===========

.. currentmodule:: mlx.core

MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.

Getting started with :func:`compile` is simple, but there are some edge cases
that are good to be aware of for more complex graphs and advanced usage.

Basics of Compile
-----------------

Let's start with a simple example:

.. code-block:: python

  def fun(x, y):
      return mx.exp(-x) + y

  x = mx.array(1.0)
  y = mx.array(2.0)

  # Regular call, no compilation
  # Prints: array(2.36788, dtype=float32)
  print(fun(x, y))

  # Compile the function
  compiled_fun = mx.compile(fun)

  # Prints: array(2.36788, dtype=float32)
  print(compiled_fun(x, y))

The output of both the regular function and the compiled function is the same
up to numerical precision.

The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.

.. code-block:: python

  def fun(x, y):
      return mx.exp(-x) + y

  x = mx.array(1.0)
  y = mx.array(2.0)

  compiled_fun = mx.compile(fun)

  # Compiled here
  compiled_fun(x, y)

  # Not compiled again
  compiled_fun(x, y)

  # Not compiled again
  mx.compile(fun)(x, y)

There are some important cases to be aware of that can cause a function to
be recompiled:

* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function

In certain cases only some of the compilation stack will be rerun (for
example when changing the shapes) and in other cases the full compilation
stack will be rerun (for example when changing the types). In general you
should avoid compiling functions too frequently.

Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:

.. code-block:: python

  a = mx.array(1.0)
  # Don't do this, compiles lambda at each iteration
  for _ in range(5):
      mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
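One way to picture the recompilation triggers and the lambda pitfall is as a cache lookup. The sketch below is a toy model in plain Python, not MLX's implementation: ``ToyArray``, ``toy_compile``, ``stats`` and the cache key are hypothetical names made up for this illustration. It assumes, as a simplification, that the cache is keyed on the function object plus the shape and type of every input, and it does not model the partial recompilation mentioned above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyArray:
    """Hypothetical stand-in for an array: only shape and dtype matter here."""
    shape: tuple
    dtype: str


_cache = {}               # (function, input signature) -> "compiled" function
stats = {"compiles": 0}   # how many times the toy compiler actually ran


def toy_compile(fn):
    """Return a wrapper that 'compiles' fn once per distinct input signature."""
    def wrapper(*args):
        # The number of inputs is implicit in the length of the signature.
        signature = tuple((a.shape, a.dtype) for a in args)
        key = (fn, signature)
        if key not in _cache:
            stats["compiles"] += 1
            _cache[key] = fn  # a real compiler would store optimized code
        return _cache[key](*args)
    return wrapper


def first_shape(*arrays):
    # The body is irrelevant to the caching behavior.
    return arrays[0].shape


compiled = toy_compile(first_shape)
a = ToyArray((2, 3), "float32")

compiled(a)                             # miss: first call
compiled(a)                             # hit: same signature
compiled(ToyArray((4, 3), "float32"))   # miss: new shape
compiled(ToyArray((2, 3), "int32"))     # miss: new type
compiled(a, a)                          # miss: new number of inputs
print(stats["compiles"])                # 4

# Every iteration creates a new lambda object, so the key never matches.
for _ in range(5):
    toy_compile(lambda x: x.shape)(a)
print(stats["compiles"])                # 9
```

In this model the fix for the loop is to create the function once, outside the loop, and reuse the compiled result.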

Example Speedup
---------------

:func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:

.. code-block:: python

  import math

  def gelu(x):
      return x * (1 + mx.erf(x / math.sqrt(2))) / 2
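To make the individual element-wise steps visible, here is a scalar reference that uses only the Python standard library. It is a sketch for sanity-checking values and is not how MLX evaluates ``gelu``; ``gelu_reference`` is a name chosen for this example.

```python
import math


def gelu_reference(x: float) -> float:
    """Scalar GELU with the same formula, one element-wise step per line."""
    scaled = x / math.sqrt(2)     # binary: divide
    erf_val = math.erf(scaled)    # unary: erf
    shifted = 1 + erf_val         # binary: add
    product = x * shifted         # binary: multiply
    return product / 2            # binary: divide


print(gelu_reference(0.0))             # 0.0
print(round(gelu_reference(1.0), 4))   # 0.8413
```

The formula is written here as five separate steps. These are the operations that compilation can fuse when they are applied to arrays.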

If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays it will be memory bandwidth bound. However, all of
the operations in ``gelu`` can be fused into a single kernel with
:func:`compile`. This can speed up both cases considerably.

Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:

.. code-block:: python

  import time

  def timeit(fun, x):
      # warm up
      for _ in range(10):
          mx.eval(fun(x))

      tic = time.perf_counter()
      for _ in range(100):
          mx.eval(fun(x))
      toc = time.perf_counter()
      tpi = 1e3 * (toc - tic) / 100
      print(f"Time per iteration {tpi:.3f} (ms)")


Now make an array, and benchmark both functions:

.. code-block:: python

  x = mx.random.uniform(shape=(32, 1000, 4096))
  timeit(nn.gelu, x)
  timeit(mx.compile(nn.gelu), x)

On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.

.. note::

  As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
  functions can still be helpful, but won't typically result in as large a
  speedup as compiling operations that run on the GPU.


Debugging
---------

When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.

.. code-block:: python

  @mx.compile
  def fun(x):
      z = -x
      print(z)  # Crash
      return mx.exp(z)

  fun(mx.array(5.0))

For debugging, inspecting arrays can be helpful. One way to do that is to
globally disable compilation using the :func:`disable_compile` function or
the ``MLX_DISABLE_COMPILE`` environment variable. For example the following is
okay even though ``fun`` is compiled:

.. code-block:: python

  @mx.compile
  def fun(x):
      z = -x
      print(z)  # Okay
      return mx.exp(z)

  mx.disable_compile()
  fun(mx.array(5.0))


Pure Functions
--------------

Compiled functions are intended to be *pure*; that is they should not have side
effects. For example:

.. code-block:: python

  state = []

  @mx.compile
  def fun(x, y):
      z = x + y
      state.append(z)
      return mx.exp(z)

  fun(mx.array(1.0), mx.array(2.0))
  # Crash!
  print(state)

After the first call of ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.
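The mechanism can be illustrated with a toy tracer in plain Python. This is a conceptual sketch only and does not reflect MLX internals: ``Placeholder``, ``evaluate`` and ``toy_compile`` are hypothetical names, the tracer supports addition only, and it traces once regardless of input shapes or types. The Python body runs a single time with placeholder inputs, so the side effect stores a placeholder rather than a value.

```python
class Placeholder:
    """Hypothetical graph node: records an operation but holds no data."""

    def __init__(self, op, args=()):
        self.op = op
        self.args = args

    def __add__(self, other):
        return Placeholder("add", (self, other))


def evaluate(node, env):
    """Replay the recorded graph using real values for the inputs."""
    if node.op == "input":
        return env[node]
    lhs, rhs = (evaluate(arg, env) for arg in node.args)
    return lhs + rhs


def toy_compile(fn):
    traced = {}

    def wrapper(*args):
        if not traced:
            # Trace once: call the Python function with placeholder inputs.
            inputs = [Placeholder("input") for _ in args]
            traced["inputs"] = inputs
            traced["output"] = fn(*inputs)
        # Later calls replay the graph; the Python body is not run again.
        env = dict(zip(traced["inputs"], args))
        return evaluate(traced["output"], env)

    return wrapper


state = []


@toy_compile
def fun(x, y):
    z = x + y
    state.append(z)  # side effect: runs only while tracing
    return z + x


first = fun(1.0, 2.0)
second = fun(10.0, 20.0)
print(first, second)             # 4.0 40.0
print(len(state))                # 1
print(type(state[0]).__name__)   # Placeholder
```

Unlike MLX, this toy does not crash when the placeholder is inspected. What it shows is that ``state`` never receives a computed value.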

You have two options to deal with this. The first option is to simply return
``state`` as an output:

.. code-block:: python

  state = []

  @mx.compile
  def fun(x, y):
      z = x + y
      state.append(z)
      return mx.exp(z), state

  _, state = fun(mx.array(1.0), mx.array(2.0))
  # Prints [array(3, dtype=float32)]
  print(state)

In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:

.. code-block:: python

  from functools import partial

  state = []

  # Tell compile to capture state as an output
  @partial(mx.compile, outputs=state)
  def fun(x, y):
      z = x + y
      state.append(z)
      return mx.exp(z), state

  fun(mx.array(1.0), mx.array(2.0))
  # Prints [array(3, dtype=float32)]
  print(state)

This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.

Compiled functions will also treat any inputs not in the parameter list as
constants. For example:

.. code-block:: python

  state = [mx.array(1.0)]

  @mx.compile
  def fun(x):
      return x + state[0]

  # Prints array(2, dtype=float32)
  print(fun(mx.array(1.0)))

  # Update state
  state[0] = mx.array(5.0)

  # Still prints array(2, dtype=float32)
  print(fun(mx.array(1.0)))
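The reason the update is ignored can be sketched with a deliberately tiny toy tracer in plain Python. This is a conceptual model, not MLX's implementation: ``Node`` and ``toy_compile`` are hypothetical names, and the toy only handles functions of the form "input plus something". The value of ``state[0]`` is read once, during tracing, and stored in the recorded graph.

```python
class Node:
    """Hypothetical graph node representing 'input + constant'."""

    def __init__(self, constant=None):
        self.constant = constant

    def __add__(self, other):
        # `other` is read once, while tracing, and baked into the graph.
        return Node(constant=other)


def toy_compile(fn):
    graph = []

    def wrapper(x):
        if not graph:
            graph.append(fn(Node()))   # trace once with a placeholder input
        return x + graph[0].constant   # replay: input + baked-in constant

    return wrapper


state = [1.0]


@toy_compile
def fun(x):
    return x + state[0]


before = fun(1.0)
state[0] = 5.0       # update the captured value after tracing
after = fun(1.0)
print(before, after)  # 2.0 2.0
```

Because the function body is not executed again after tracing, the new value of ``state[0]`` is never read.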

In order to have the change of state reflected in the outputs of ``fun`` you
again have two options. The first option is to simply pass ``state`` as input
to the function. In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:

.. code-block:: python

  from functools import partial
  state = [mx.array(1.0)]

  # Tell compile to capture state as an input
  @partial(mx.compile, inputs=state)
  def fun(x):
      return x + state[0]

  # Prints array(2, dtype=float32)
  print(fun(mx.array(1.0)))

  # Update state
  state[0] = mx.array(5.0)

  # Prints array(6, dtype=float32)
  print(fun(mx.array(1.0)))


Compiling Training Graphs
-------------------------

This section will step through how to use :func:`compile` with a simple example
of a common setup: training a model with :obj:`mlx.nn.Module` using an
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
full forward, backward, and update with :func:`compile`.

To start, here is the simple example without any compilation:

.. code-block:: python

  import mlx.core as mx
  import mlx.nn as nn
  import mlx.optimizers as optim

  # 4 examples with 10 features each
  x = mx.random.uniform(shape=(4, 10))

  # 0, 1 targets
  y = mx.array([0, 1, 0, 1])

  # Simple linear model
  model = nn.Linear(10, 1)

  # SGD with momentum
  optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

  def loss_fn(model, x, y):
      logits = model(x).squeeze()
      return nn.losses.binary_cross_entropy(logits, y)

  loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

  # Perform 10 steps of gradient descent
  for it in range(10):
      loss, grads = loss_and_grad_fn(model, x, y)
      optimizer.update(model, grads)
      mx.eval(model.parameters(), optimizer.state)

To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:

.. code-block:: python

  import mlx.core as mx
  import mlx.nn as nn
  import mlx.optimizers as optim
  from functools import partial

  # 4 examples with 10 features each
  x = mx.random.uniform(shape=(4, 10))

  # 0, 1 targets
  y = mx.array([0, 1, 0, 1])

  # Simple linear model
  model = nn.Linear(10, 1)

  # SGD with momentum
  optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

  def loss_fn(model, x, y):
      logits = model(x).squeeze()
      return nn.losses.binary_cross_entropy(logits, y)

  # The state that will be captured as input and output
  state = [model.state, optimizer.state]

  @partial(mx.compile, inputs=state, outputs=state)
  def step(x, y):
      loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
      loss, grads = loss_and_grad_fn(model, x, y)
      optimizer.update(model, grads)
      return loss

  # Perform 10 steps of gradient descent
  for it in range(10):
      loss = step(x, y)
      # Evaluate the model and optimizer state
      mx.eval(state)
      print(loss)


.. note::

  If you are using a module which performs random sampling such as
  :class:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
  ``state`` captured by :func:`compile`, i.e. ``state = [model.state,
  optimizer.state, mx.random.state]``.


.. note::

  For more examples of compiling full training graphs check out the `MLX
  Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.

Transformations with Compile
----------------------------

In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
<function_transforms>`.

Compiling transformed functions works just as expected:

.. code-block:: python

  grad_fn = mx.grad(mx.exp)

  compiled_grad_fn = mx.compile(grad_fn)

  # Prints: array(2.71828, dtype=float32)
  print(grad_fn(mx.array(1.0)))

  # Also prints: array(2.71828, dtype=float32)
  print(compiled_grad_fn(mx.array(1.0)))
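The expected value follows from the fact that the derivative of ``exp`` is ``exp`` itself, so the gradient at 1 is *e*. As a quick check that does not depend on MLX, the sketch below estimates the derivative with a central difference using only the standard library. ``central_difference`` is a helper written for this example, and the step size ``h`` is an arbitrary choice.

```python
import math


def central_difference(f, x, h=1e-5):
    """Estimate f'(x) numerically with a symmetric finite difference."""
    return (f(x + h) - f(x - h)) / (2 * h)


estimate = central_difference(math.exp, 1.0)
print(round(estimate, 5))   # 2.71828
print(round(math.e, 5))     # 2.71828
```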

.. note::

  In order to compile as much as possible, a transformation of a compiled
  function will not by default be compiled. To compile the transformed
  function simply pass it through :func:`compile`.

You can also compile functions which themselves call compiled functions. A
good practice is to compile the outermost function to give :func:`compile`
the most opportunity to optimize the computation graph:

.. code-block:: python

  @mx.compile
  def inner(x):
      return mx.exp(-mx.abs(x))

  def outer(x):
      return inner(inner(x))

  # Compiling the outer function is good to do as it will likely
  # be faster even though the inner functions are compiled
  fun = mx.compile(outer)

@@ -5,9 +5,12 @@ Function Transforms

 .. currentmodule:: mlx.core

-MLX uses composable function transformations for automatic differentiation and
-vectorization. The key idea behind composable function transformations is that
-every transformation returns a function which can be further transformed.
+MLX uses composable function transformations for automatic differentiation,
+vectorization, and compute graph optimizations. To see the complete list of
+function transformations check out the :ref:`API documentation <transforms>`.
+
+The key idea behind composable function transformations is that every
+transformation returns a function which can be further transformed.

 Here is a simple example:

@@ -36,10 +39,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep
 getting higher order derivatives.

 Any of the MLX function transformations can be composed in any order to any
-depth. To see the complete list of function transformations check-out the
-:ref:`API documentation <transforms>`. See the following sections for more
-information on :ref:`automatic differentiaion <auto diff>` and
-:ref:`automatic vectorization <vmap>`.
+depth. See the following sections for more information on :ref:`automatic
+differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
+For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
+

 Automatic Differentiation
 -------------------------

@@ -1008,7 +1008,7 @@ void init_transforms(py::module_& m) {
 "enable_compile",
 &enable_compile,
 R"pbdoc(
-enable_compiler() -> None
+enable_compile() -> None

 Globally enable compilation. This will override the environment
 variable ``MLX_DISABLE_COMPILE`` if set.