mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			498 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			ReStructuredText
		
	
	
	
	
	
			
		
		
	
	
			498 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			ReStructuredText
		
	
	
	
	
	
.. _compile:
 | 
						|
 | 
						|
Compilation
 | 
						|
===========
 | 
						|
 | 
						|
.. currentmodule:: mlx.core
 | 
						|
 | 
						|
MLX has a :func:`compile` function transformation which compiles computation
 | 
						|
graphs. Function compilation results in smaller graphs by merging common work
 | 
						|
and fusing certain operations. In many cases this can lead to big improvements
 | 
						|
in run-time and memory use.
 | 
						|
 | 
						|
Getting started with :func:`compile` is simple, but there are some edge cases
 | 
						|
that are good to be aware of for more complex graphs and advanced usage.
 | 
						|
 | 
						|
Basics of Compile
 | 
						|
-----------------
 | 
						|
 | 
						|
Let's start with a simple example:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  def fun(x, y):
 | 
						|
      return mx.exp(-x) + y
 | 
						|
 | 
						|
  x = mx.array(1.0)
 | 
						|
  y = mx.array(2.0)
 | 
						|
 | 
						|
  # Regular call, no compilation
 | 
						|
  # Prints: array(2.36788, dtype=float32)
 | 
						|
  print(fun(x, y))
 | 
						|
 | 
						|
  # Compile the function
 | 
						|
  compiled_fun = mx.compile(fun)
 | 
						|
 | 
						|
  # Prints: array(2.36788, dtype=float32)
 | 
						|
  print(compiled_fun(x, y))
 | 
						|
 | 
						|
The output of both the regular function and the compiled function is the same
 | 
						|
up to numerical precision.
 | 
						|
 | 
						|
The first time you call a compiled function, MLX will build the compute
 | 
						|
graph, optimize it, and generate and compile code. This can be relatively
 | 
						|
slow. However, MLX will cache compiled functions, so calling a compiled
 | 
						|
function multiple times will not initiate a new compilation. This means you
 | 
						|
should typically compile functions that you plan to use more than once.
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  def fun(x, y):
 | 
						|
      return mx.exp(-x) + y
 | 
						|
 | 
						|
  x = mx.array(1.0)
 | 
						|
  y = mx.array(2.0)
 | 
						|
 | 
						|
  compiled_fun = mx.compile(fun)
 | 
						|
 | 
						|
  # Compiled here
 | 
						|
  compiled_fun(x, y)
 | 
						|
 | 
						|
  # Not compiled again
 | 
						|
  compiled_fun(x, y)
 | 
						|
 | 
						|
  # Not compiled again
 | 
						|
  mx.compile(fun)(x, y)
 | 
						|
 | 
						|
There are some important cases to be aware of that can cause a function to
 | 
						|
be recompiled:
 | 
						|
 | 
						|
* Changing the shape or number of dimensions
 | 
						|
* Changing the type of any of the inputs
 | 
						|
* Changing the number of inputs to the function
 | 
						|
 | 
						|
In certain cases only some of the compilation stack will be rerun (for
 | 
						|
example when changing the shapes) and in other cases the full compilation
 | 
						|
stack will be rerun (for example when changing the types). In general you
 | 
						|
should avoid compiling functions too frequently.
 | 
						|
 | 
						|
Another idiom to watch out for is compiling functions which get created and
 | 
						|
destroyed frequently. This can happen, for example, when compiling an anonymous
 | 
						|
function in a loop:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  a = mx.array(1.0)
 | 
						|
  # Don't do this, compiles lambda at each iteration
 | 
						|
  for _ in range(5):
 | 
						|
      mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
 | 
						|
 | 
						|
Example Speedup
 | 
						|
---------------
 | 
						|
 | 
						|
The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
 | 
						|
Transformer-based models. The implementation involves several unary and binary
 | 
						|
element-wise operations:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  def gelu(x):
 | 
						|
      return x * (1 + mx.erf(x / math.sqrt(2))) / 2
 | 
						|
 | 
						|
If you use this function with small arrays, it will be overhead bound. If you
 | 
						|
use it with large arrays it will be memory bandwidth bound.  However, all of
 | 
						|
the operations in the ``gelu`` are fusible into a single kernel with
 | 
						|
:func:`compile`. This can speedup both cases considerably.
 | 
						|
 | 
						|
Let's compare the runtime of the regular function versus the compiled
 | 
						|
function. We'll use the following timing helper which does a warm up and
 | 
						|
handles synchronization:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  import time
 | 
						|
 | 
						|
  def timeit(fun, x):
 | 
						|
      # warm up
 | 
						|
      for _ in range(10):
 | 
						|
          mx.eval(fun(x))
 | 
						|
 | 
						|
      tic = time.perf_counter()
 | 
						|
      for _ in range(100):
 | 
						|
          mx.eval(fun(x))
 | 
						|
      toc = time.perf_counter()
 | 
						|
      tpi = 1e3 * (toc - tic) / 100
 | 
						|
      print(f"Time per iteration {tpi:.3f} (ms)")
 | 
						|
 | 
						|
 | 
						|
Now make an array, and benchmark both functions:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  x = mx.random.uniform(shape=(32, 1000, 4096))
 | 
						|
  timeit(nn.gelu, x)
 | 
						|
  timeit(mx.compile(nn.gelu), x)
 | 
						|
 | 
						|
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
 | 
						|
five times faster.
 | 
						|
 | 
						|
Debugging
 | 
						|
---------
 | 
						|
 | 
						|
When a compiled function is first called, it is traced with placeholder
 | 
						|
inputs. This means you can't evaluate arrays (for example to print their
 | 
						|
contents) inside compiled functions.
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  @mx.compile
 | 
						|
  def fun(x):
 | 
						|
      z = -x
 | 
						|
      print(z)  # Crash
 | 
						|
      return mx.exp(z)
 | 
						|
 | 
						|
  fun(mx.array(5.0))
 | 
						|
 | 
						|
For debugging, inspecting arrays can be helpful. One way to do that is to
 | 
						|
globally disable compilation using the :func:`disable_compile` function or
 | 
						|
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
 | 
						|
``fun`` is compiled:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  @mx.compile
 | 
						|
  def fun(x):
 | 
						|
      z = -x
 | 
						|
      print(z) # Okay
 | 
						|
      return mx.exp(z)
 | 
						|
 | 
						|
  mx.disable_compile()
 | 
						|
  fun(mx.array(5.0))
 | 
						|
 | 
						|
 | 
						|
Pure Functions
 | 
						|
--------------
 | 
						|
 | 
						|
Compiled functions are intended to be *pure*; that is they should not have side
 | 
						|
effects. For example:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  state = []
 | 
						|
 | 
						|
  @mx.compile
 | 
						|
  def fun(x, y):
 | 
						|
      z = x + y
 | 
						|
      state.append(z)
 | 
						|
      return mx.exp(z)
 | 
						|
 | 
						|
  fun(mx.array(1.0), mx.array(2.0))
 | 
						|
  # Crash!
 | 
						|
  print(state)
 | 
						|
 | 
						|
After the first call of ``fun``, the ``state`` list will hold a placeholder
 | 
						|
array. The placeholder does not have any data; it is only used to build the
 | 
						|
computation graph. Printing such an array results in a crash.
 | 
						|
 | 
						|
You have two options to deal with this. The first option is to simply return
 | 
						|
``state`` as an output:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
   state = []
 | 
						|
 | 
						|
   @mx.compile
 | 
						|
   def fun(x, y):
 | 
						|
      z = x + y
 | 
						|
      state.append(z)
 | 
						|
      return mx.exp(z), state
 | 
						|
 | 
						|
    _, state = fun(mx.array(1.0), mx.array(2.0))
 | 
						|
    # Prints [array(3, dtype=float32)]
 | 
						|
    print(state)
 | 
						|
 | 
						|
In some cases returning updated state can be pretty inconvenient. Hence,
 | 
						|
:func:`compile` has a parameter to capture implicit outputs:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  from functools import partial
 | 
						|
 | 
						|
  state = []
 | 
						|
 | 
						|
  # Tell compile to capture state as an output
 | 
						|
  @partial(mx.compile, outputs=state)
 | 
						|
  def fun(x, y):
 | 
						|
      z = x + y
 | 
						|
      state.append(z)
 | 
						|
      return mx.exp(z), state
 | 
						|
 | 
						|
  fun(mx.array(1.0), mx.array(2.0))
 | 
						|
  # Prints [array(3, dtype=float32)]
 | 
						|
  print(state)
 | 
						|
 | 
						|
This is particularly useful for compiling a function which includes an update
 | 
						|
to a container of arrays, as is commonly done when training the parameters of a
 | 
						|
:class:`mlx.nn.Module`.
 | 
						|
 | 
						|
Compiled functions will also treat any inputs not in the parameter list as
 | 
						|
constants. For example:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  state = [mx.array(1.0)]
 | 
						|
 | 
						|
  @mx.compile
 | 
						|
  def fun(x):
 | 
						|
      return x + state[0]
 | 
						|
 | 
						|
  # Prints array(2, dtype=float32)
 | 
						|
  print(fun(mx.array(1.0)))
 | 
						|
 | 
						|
  # Update state
 | 
						|
  state[0] = mx.array(5.0)
 | 
						|
 | 
						|
  # Still prints array(2, dtype=float32)
 | 
						|
  print(fun(mx.array(1.0)))
 | 
						|
 | 
						|
In order to have the change of state reflected in the outputs of ``fun`` you
 | 
						|
again have two options. The first option is to simply pass ``state`` as input
 | 
						|
to the function. In some cases this can be pretty inconvenient. Hence,
 | 
						|
:func:`compile` also has a parameter to capture implicit inputs:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  from functools import partial
 | 
						|
  state = [mx.array(1.0)]
 | 
						|
 | 
						|
  # Tell compile to capture state as an input
 | 
						|
  @partial(mx.compile, inputs=state)
 | 
						|
  def fun(x):
 | 
						|
      return x + state[0]
 | 
						|
 | 
						|
  # Prints array(2, dtype=float32)
 | 
						|
  print(fun(mx.array(1.0)))
 | 
						|
 | 
						|
  # Update state
 | 
						|
  state[0] = mx.array(5.0)
 | 
						|
 | 
						|
  # Prints array(6, dtype=float32)
 | 
						|
  print(fun(mx.array(1.0)))
 | 
						|
 | 
						|
 | 
						|
Compiling Training Graphs
 | 
						|
-------------------------
 | 
						|
 | 
						|
This section will step through how to use :func:`compile` with a simple example
 | 
						|
of a common setup: training a model with :obj:`mlx.nn.Module` using an
 | 
						|
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
 | 
						|
full forward, backward, and update with :func:`compile`.
 | 
						|
 | 
						|
To start, here is the simple example without any compilation:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  import mlx.core as mx
 | 
						|
  import mlx.nn as nn
 | 
						|
  import mlx.optimizers as optim
 | 
						|
 | 
						|
  # 4 examples with 10 features each
 | 
						|
  x = mx.random.uniform(shape=(4, 10))
 | 
						|
 | 
						|
  # 0, 1 targets
 | 
						|
  y = mx.array([0, 1, 0, 1])
 | 
						|
 | 
						|
  # Simple linear model
 | 
						|
  model = nn.Linear(10, 1)
 | 
						|
 | 
						|
  # SGD with momentum
 | 
						|
  optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
 | 
						|
 | 
						|
  def loss_fn(model, x, y):
 | 
						|
      logits = model(x).squeeze()
 | 
						|
      return nn.losses.binary_cross_entropy(logits, y)
 | 
						|
 | 
						|
  loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
 | 
						|
 | 
						|
  # Perform 10 steps of gradient descent
 | 
						|
  for it in range(10):
 | 
						|
      loss, grads = loss_and_grad_fn(model, x, y)
 | 
						|
      optimizer.update(model, grads)
 | 
						|
      mx.eval(model.parameters(), optimizer.state)
 | 
						|
 | 
						|
To compile the update we can put it all in a function and compile it with the
 | 
						|
appropriate input and output captures. Here's the same example but compiled:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  import mlx.core as mx
 | 
						|
  import mlx.nn as nn
 | 
						|
  import mlx.optimizers as optim
 | 
						|
  from functools import partial
 | 
						|
 | 
						|
  # 4 examples with 10 features each
 | 
						|
  x = mx.random.uniform(shape=(4, 10))
 | 
						|
 | 
						|
  # 0, 1 targets
 | 
						|
  y = mx.array([0, 1, 0, 1])
 | 
						|
 | 
						|
  # Simple linear model
 | 
						|
  model = nn.Linear(10, 1)
 | 
						|
 | 
						|
  # SGD with momentum
 | 
						|
  optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
 | 
						|
 | 
						|
  def loss_fn(model, x, y):
 | 
						|
      logits = model(x).squeeze()
 | 
						|
      return nn.losses.binary_cross_entropy(logits, y)
 | 
						|
 | 
						|
  # The state that will be captured as input and output
 | 
						|
  state = [model.state, optimizer.state]
 | 
						|
 | 
						|
  @partial(mx.compile, inputs=state, outputs=state)
 | 
						|
  def step(x, y):
 | 
						|
      loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
 | 
						|
      loss, grads = loss_and_grad_fn(model, x, y)
 | 
						|
      optimizer.update(model, grads)
 | 
						|
      return loss
 | 
						|
 | 
						|
  # Perform 10 steps of gradient descent
 | 
						|
  for it in range(10):
 | 
						|
      loss = step(x, y)
 | 
						|
      # Evaluate the model and optimizer state
 | 
						|
      mx.eval(state)
 | 
						|
      print(loss)
 | 
						|
 | 
						|
 | 
						|
.. note::
 | 
						|
 | 
						|
  If you are using a module which performs random sampling such as
 | 
						|
  :func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
 | 
						|
  ``state`` captured by :func:`compile`, i.e. ``state = [model.state,
 | 
						|
  optimizer.state, mx.random.state]``.
 | 
						|
 | 
						|
 | 
						|
.. note::
 | 
						|
 | 
						|
   For more examples of compiling full training graphs checkout the  `MLX
 | 
						|
   Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.
 | 
						|
 | 
						|
Transformations with Compile
 | 
						|
----------------------------
 | 
						|
 | 
						|
In MLX function transformations are composable. You can apply any function
 | 
						|
transformation to the output of any other function transformation. For more on
 | 
						|
this, see the documentation on :ref:`function transforms
 | 
						|
<function_transforms>`.
 | 
						|
 | 
						|
Compiling transformed functions works just as expected:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  grad_fn = mx.grad(mx.exp)
 | 
						|
 | 
						|
  compiled_grad_fn = mx.compile(grad_fn)
 | 
						|
 | 
						|
  # Prints: array(2.71828, dtype=float32)
 | 
						|
  print(grad_fn(mx.array(1.0)))
 | 
						|
 | 
						|
  # Also prints: array(2.71828, dtype=float32)
 | 
						|
  print(compiled_grad_fn(mx.array(1.0)))
 | 
						|
 | 
						|
.. note::
 | 
						|
 | 
						|
   In order to compile as much as possible, a transformation of a compiled
 | 
						|
   function will not by default be compiled. To compile the transformed
 | 
						|
   function simply pass it through :func:`compile`.
 | 
						|
 | 
						|
You can also compile functions which themselves call compiled functions. A
 | 
						|
good practice is to compile the outer most function to give :func:`compile`
 | 
						|
the most opportunity to optimize the computation graph:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  @mx.compile
 | 
						|
  def inner(x):
 | 
						|
      return mx.exp(-mx.abs(x))
 | 
						|
 | 
						|
  def outer(x):
 | 
						|
      inner(inner(x))
 | 
						|
 | 
						|
  # Compiling the outer function is good to do as it will likely
 | 
						|
  # be faster even though the inner functions are compiled
 | 
						|
  fun = mx.compile(outer)
 | 
						|
 | 
						|
 | 
						|
 | 
						|
.. _shapeless_compile:
 | 
						|
 | 
						|
Shapeless Compilation
 | 
						|
---------------------
 | 
						|
 | 
						|
When the shape of an input to a compiled function changes, the function is
 | 
						|
recompiled. You can compile a function once and run it on inputs with
 | 
						|
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
 | 
						|
case changes to the shapes of the inputs do not cause the function to be
 | 
						|
recompiled.
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  def fun(x, y):
 | 
						|
      return mx.abs(x + y)
 | 
						|
 | 
						|
  compiled_fun = mx.compile(fun, shapeless=True)
 | 
						|
 | 
						|
  x = mx.array(1.0)
 | 
						|
  y = mx.array(-2.0)
 | 
						|
 | 
						|
  # Firt call compiles the function
 | 
						|
  print(compiled_fun(x, y))
 | 
						|
 | 
						|
  # Second call with different shapes
 | 
						|
  # does not recompile the function
 | 
						|
  x = mx.array([1.0, -6.0])
 | 
						|
  y = mx.array([-2.0, 3.0])
 | 
						|
  print(compiled_fun(x, y))
 | 
						|
 | 
						|
 | 
						|
Use shapeless compilations carefully. Since compilation is not triggered when
 | 
						|
shapes change, any graphs which are conditional on the input shapes will not
 | 
						|
work as expected. Shape-dependent computations are common and sometimes subtle
 | 
						|
to detect. For example:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  def fun(x):
 | 
						|
      return x.reshape(x.shape[0] * x.shape[1], -1)
 | 
						|
 | 
						|
  compiled_fun = mx.compile(fun, shapeless=True)
 | 
						|
 | 
						|
  x = mx.random.uniform(shape=(2, 3, 4))
 | 
						|
 | 
						|
  out = compiled_fun(x)
 | 
						|
 | 
						|
  x = mx.random.uniform(shape=(5, 5, 3))
 | 
						|
 | 
						|
  # Error, can't reshape (5, 5, 3) to (6, -1)
 | 
						|
  out = compiled_fun(x)
 | 
						|
 | 
						|
The second call to the ``compiled_fun`` fails because of the call to
 | 
						|
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
 | 
						|
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:
 | 
						|
 | 
						|
.. code-block:: python
 | 
						|
 | 
						|
  def fun(x):
 | 
						|
      return x.flatten(0, 1)
 | 
						|
 | 
						|
  compiled_fun = mx.compile(fun, shapeless=True)
 | 
						|
 | 
						|
  x = mx.random.uniform(shape=(2, 3, 4))
 | 
						|
 | 
						|
  out = compiled_fun(x)
 | 
						|
 | 
						|
  x = mx.random.uniform(shape=(5, 5, 3))
 | 
						|
 | 
						|
  # Ok
 | 
						|
  out = compiled_fun(x)
 |