mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
rebase
This commit is contained in:
23
docs/build/html/_sources/usage/compile.rst
vendored
23
docs/build/html/_sources/usage/compile.rst
vendored
@@ -33,12 +33,12 @@ Let's start with a simple example:
|
||||
# Compile the function
|
||||
compiled_fun = mx.compile(fun)
|
||||
|
||||
# Prints: array(2.36788, dtype=float32)
|
||||
# Prints: array(2.36788, dtype=float32)
|
||||
print(compiled_fun(x, y))
|
||||
|
||||
The output of both the regular function and the compiled function is the same
|
||||
up to numerical precision.
|
||||
|
||||
|
||||
The first time you call a compiled function, MLX will build the compute
|
||||
graph, optimize it, and generate and compile code. This can be relatively
|
||||
slow. However, MLX will cache compiled functions, so calling a compiled
|
||||
@@ -96,7 +96,7 @@ element-wise operations:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def gelu(x):
|
||||
def gelu(x):
|
||||
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
||||
|
||||
If you use this function with small arrays, it will be overhead bound. If you
|
||||
@@ -136,13 +136,6 @@ Now make an array, and benchmark both functions:
|
||||
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
|
||||
five times faster.
|
||||
|
||||
.. note::
|
||||
|
||||
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
|
||||
functions can still be helpful, but won't typically result in as large a
|
||||
speedup as compiling operations that run on the GPU.
|
||||
|
||||
|
||||
Debugging
|
||||
---------
|
||||
|
||||
@@ -287,7 +280,7 @@ to the function. In some cases this can be pretty inconvenient. Hence,
|
||||
print(fun(mx.array(1.0)))
|
||||
|
||||
|
||||
Compiling Training Graphs
|
||||
Compiling Training Graphs
|
||||
-------------------------
|
||||
|
||||
This section will step through how to use :func:`compile` with a simple example
|
||||
@@ -297,7 +290,7 @@ full forward, backward, and update with :func:`compile`.
|
||||
|
||||
To start, here is the simple example without any compilation:
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -330,7 +323,7 @@ To start, here is the simple example without any compilation:
|
||||
To compile the update we can put it all in a function and compile it with the
|
||||
appropriate input and output captures. Here's the same example but compiled:
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -355,7 +348,7 @@ appropriate input and output captures. Here's the same example but compiled:
|
||||
|
||||
# The state that will be captured as input and output
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def step(x, y):
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
@@ -410,7 +403,7 @@ Compiling transformed functions works just as expected:
|
||||
|
||||
In order to compile as much as possible, a transformation of a compiled
|
||||
function will not by default be compiled. To compile the transformed
|
||||
function simply pass it through :func:`compile`.
|
||||
function simply pass it through :func:`compile`.
|
||||
|
||||
You can also compile functions which themselves call compiled functions. A
|
||||
good practice is to compile the outer most function to give :func:`compile`
|
||||
|
Reference in New Issue
Block a user