	| @@ -41,6 +41,7 @@ are the CPU and GPU. | ||||
|    usage/indexing | ||||
|    usage/saving_and_loading | ||||
|    usage/function_transforms | ||||
|    usage/compile | ||||
|    usage/numpy | ||||
|    usage/using_streams | ||||
|  | ||||
|   | ||||
| @@ -9,6 +9,9 @@ Transforms | ||||
|   :toctree: _autosummary | ||||
|  | ||||
|    eval | ||||
|    compile | ||||
|    disable_compile | ||||
|    enable_compile | ||||
|    grad | ||||
|    value_and_grad | ||||
|    jvp | ||||
|   | ||||
docs/src/usage/compile.rst (new file, 430 lines)
							| @@ -0,0 +1,430 @@ | ||||
| .. _compile: | ||||
|  | ||||
| Compilation | ||||
| =========== | ||||
|  | ||||
| .. currentmodule:: mlx.core | ||||
|  | ||||
| MLX has a :func:`compile` function transformation which compiles computation | ||||
| graphs. Function compilation results in smaller graphs by merging common work | ||||
| and fusing certain operations. In many cases this can lead to big improvements | ||||
| in run-time and memory use. | ||||
|  | ||||
| Getting started with :func:`compile` is simple, but there are some edge cases | ||||
| that are good to be aware of for more complex graphs and advanced usage. | ||||
|  | ||||
| Basics of Compile | ||||
| ----------------- | ||||
|  | ||||
| Let's start with a simple example: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   def fun(x, y): | ||||
|       return mx.exp(-x) + y | ||||
|  | ||||
|   x = mx.array(1.0) | ||||
|   y = mx.array(2.0) | ||||
|  | ||||
|   # Regular call, no compilation | ||||
|   # Prints: array(2.36788, dtype=float32) | ||||
|   print(fun(x, y)) | ||||
|  | ||||
|   # Compile the function | ||||
|   compiled_fun = mx.compile(fun) | ||||
|  | ||||
|   # Prints: array(2.36788, dtype=float32)  | ||||
|   print(compiled_fun(x, y)) | ||||
|  | ||||
| The output of both the regular function and the compiled function is the same | ||||
| up to numerical precision. | ||||
|  | ||||
| The first time you call a compiled function, MLX will build the compute | ||||
| graph, optimize it, and generate and compile code. This can be relatively | ||||
| slow. However, MLX will cache compiled functions, so calling a compiled | ||||
| function multiple times will not initiate a new compilation. This means you | ||||
| should typically compile functions that you plan to use more than once. | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   def fun(x, y): | ||||
|       return mx.exp(-x) + y | ||||
|  | ||||
|   x = mx.array(1.0) | ||||
|   y = mx.array(2.0) | ||||
|  | ||||
|   compiled_fun = mx.compile(fun) | ||||
|  | ||||
|   # Compiled here | ||||
|   compiled_fun(x, y) | ||||
|  | ||||
|   # Not compiled again | ||||
|   compiled_fun(x, y) | ||||
|  | ||||
|   # Not compiled again | ||||
|   mx.compile(fun)(x, y) | ||||
|  | ||||
| There are some important cases to be aware of that can cause a function to | ||||
| be recompiled: | ||||
|  | ||||
| * Changing the shape or number of dimensions of any of the inputs | ||||
| * Changing the type of any of the inputs | ||||
| * Changing the number of inputs to the function | ||||
|  | ||||
| In certain cases, only part of the compilation stack will be rerun (for | ||||
| example, when changing the shapes) and in other cases the full compilation | ||||
| stack will be rerun (for example, when changing the types). In general, you | ||||
| should avoid compiling functions too frequently. | ||||
|  | ||||
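| For instance, here is a minimal sketch of these cases using ``fun`` from | ||||
| above (the shapes and types are only illustrative): | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   compiled_fun = mx.compile(fun) | ||||
|  | ||||
|   # Compiles for float32 inputs with shape (2,) | ||||
|   compiled_fun(mx.zeros((2,)), mx.zeros((2,))) | ||||
|  | ||||
|   # Same shapes and types: reuses the cached compilation | ||||
|   compiled_fun(mx.ones((2,)), mx.ones((2,))) | ||||
|  | ||||
|   # New shape: part of the compilation stack is rerun | ||||
|   compiled_fun(mx.zeros((3,)), mx.zeros((3,))) | ||||
|  | ||||
|   # New type: the full compilation stack is rerun | ||||
|   compiled_fun(mx.zeros((2,), dtype=mx.float16), mx.zeros((2,), dtype=mx.float16)) | ||||
|  | ||||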
| Another idiom to watch out for is compiling functions which get created and | ||||
| destroyed frequently. This can happen, for example, when compiling an anonymous | ||||
| function in a loop: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   a = mx.array(1.0) | ||||
|   # Don't do this, compiles lambda at each iteration | ||||
|   for _ in range(5): | ||||
|       mx.compile(lambda x: mx.exp(mx.abs(x)))(a) | ||||
|  | ||||
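| Instead, compile the function once outside the loop and reuse it: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   a = mx.array(1.0) | ||||
|  | ||||
|   # Compile once and reuse across iterations | ||||
|   fun = mx.compile(lambda x: mx.exp(mx.abs(x))) | ||||
|   for _ in range(5): | ||||
|       fun(a) | ||||
|  | ||||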
| Example Speedup | ||||
| --------------- | ||||
|  | ||||
| :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with | ||||
| Transformer-based models. The implementation involves several unary and binary | ||||
| element-wise operations: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   import math | ||||
|  | ||||
|   def gelu(x): | ||||
|       return x * (1 + mx.erf(x / math.sqrt(2))) / 2 | ||||
|  | ||||
| If you use this function with small arrays, it will be overhead bound. If you | ||||
| use it with large arrays, it will be memory bandwidth bound. However, all of | ||||
| the operations in ``gelu`` can be fused into a single kernel with | ||||
| :func:`compile`. This can speed up both cases considerably. | ||||
|  | ||||
| Let's compare the runtime of the regular function versus the compiled | ||||
| function. We'll use the following timing helper, which does a warm-up and | ||||
| handles synchronization: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   import time | ||||
|  | ||||
|   def timeit(fun, x): | ||||
|       # warm up | ||||
|       for _ in range(10): | ||||
|           mx.eval(fun(x)) | ||||
|  | ||||
|       tic = time.perf_counter() | ||||
|       for _ in range(100): | ||||
|           mx.eval(fun(x)) | ||||
|       toc = time.perf_counter() | ||||
|       tpi = 1e3 * (toc - tic) / 100 | ||||
|       print(f"Time per iteration {tpi:.3f} (ms)") | ||||
|  | ||||
|  | ||||
| Now make an array, and benchmark both functions: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   import mlx.nn as nn | ||||
|  | ||||
|   x = mx.random.uniform(shape=(32, 1000, 4096)) | ||||
|   timeit(nn.gelu, x) | ||||
|   timeit(mx.compile(nn.gelu), x) | ||||
|  | ||||
| On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is | ||||
| five times faster. | ||||
|  | ||||
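| The same comparison can be run with a small array to see the overhead-bound | ||||
| case. The exact numbers will vary by machine, but the compiled version should | ||||
| still come out ahead: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   # Small array: per-operation overhead dominates the runtime | ||||
|   x = mx.random.uniform(shape=(100,)) | ||||
|   timeit(nn.gelu, x) | ||||
|   timeit(mx.compile(nn.gelu), x) | ||||
|  | ||||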
| .. note:: | ||||
|  | ||||
|   As of the latest MLX, CPU functions are not fully compiled. Compiling CPU | ||||
|   functions can still be helpful, but won't typically result in as large a | ||||
|   speedup as compiling operations that run on the GPU. | ||||
|  | ||||
|  | ||||
| Debugging | ||||
| --------- | ||||
|  | ||||
| When a compiled function is first called, it is traced with placeholder | ||||
| inputs. This means you can't evaluate arrays (for example to print their | ||||
| contents) inside compiled functions. | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   @mx.compile | ||||
|   def fun(x): | ||||
|       z = -x | ||||
|       print(z)  # Crash | ||||
|       return mx.exp(z) | ||||
|  | ||||
|   fun(mx.array(5.0)) | ||||
|  | ||||
| For debugging, inspecting arrays can be helpful. One way to do that is to | ||||
| globally disable compilation using the :func:`disable_compile` function or | ||||
| the ``MLX_DISABLE_COMPILE`` environment variable. For example, the following | ||||
| is okay even though ``fun`` is compiled: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   @mx.compile | ||||
|   def fun(x): | ||||
|       z = -x | ||||
|       print(z) # Okay | ||||
|       return mx.exp(z) | ||||
|  | ||||
|   mx.disable_compile() | ||||
|   fun(mx.array(5.0)) | ||||
|  | ||||
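| Compilation can be turned back on globally with the corresponding | ||||
| :func:`enable_compile` function: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   # Re-enable compilation | ||||
|   mx.enable_compile() | ||||
|  | ||||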
|  | ||||
| Pure Functions | ||||
| -------------- | ||||
|  | ||||
| Compiled functions are intended to be *pure*; that is, they should not have | ||||
| side effects. For example: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   state = [] | ||||
|  | ||||
|   @mx.compile | ||||
|   def fun(x, y): | ||||
|       z = x + y | ||||
|       state.append(z) | ||||
|       return mx.exp(z) | ||||
|  | ||||
|   fun(mx.array(1.0), mx.array(2.0)) | ||||
|   # Crash! | ||||
|   print(state) | ||||
|  | ||||
| After the first call of ``fun``, the ``state`` list will hold a placeholder | ||||
| array. The placeholder does not have any data; it is only used to build the | ||||
| computation graph. Printing such an array results in a crash. | ||||
|  | ||||
| You have two options to deal with this. The first option is to simply return | ||||
| ``state`` as an output: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   state = [] | ||||
|  | ||||
|   @mx.compile | ||||
|   def fun(x, y): | ||||
|       z = x + y | ||||
|       state.append(z) | ||||
|       return mx.exp(z), state | ||||
|  | ||||
|   _, state = fun(mx.array(1.0), mx.array(2.0)) | ||||
|   # Prints [array(3, dtype=float32)] | ||||
|   print(state) | ||||
|  | ||||
| In some cases, returning the updated state can be pretty inconvenient. Hence, | ||||
| :func:`compile` has a parameter to capture implicit outputs: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   from functools import partial | ||||
|  | ||||
|   state = [] | ||||
|  | ||||
|   # Tell compile to capture state as an output | ||||
|   @partial(mx.compile, outputs=state) | ||||
|   def fun(x, y): | ||||
|       z = x + y | ||||
|       state.append(z) | ||||
|       return mx.exp(z) | ||||
|  | ||||
|   fun(mx.array(1.0), mx.array(2.0)) | ||||
|   # Prints [array(3, dtype=float32)] | ||||
|   print(state) | ||||
|  | ||||
| This is particularly useful for compiling a function which includes an update | ||||
| to a container of arrays, as is commonly done when training the parameters of a | ||||
| :class:`mlx.nn.Module`. | ||||
|  | ||||
| Compiled functions will also treat any inputs not in the parameter list as | ||||
| constants. For example: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   state = [mx.array(1.0)] | ||||
|  | ||||
|   @mx.compile | ||||
|   def fun(x): | ||||
|       return x + state[0] | ||||
|  | ||||
|   # Prints array(2, dtype=float32) | ||||
|   print(fun(mx.array(1.0))) | ||||
|  | ||||
|   # Update state | ||||
|   state[0] = mx.array(5.0) | ||||
|  | ||||
|   # Still prints array(2, dtype=float32) | ||||
|   print(fun(mx.array(1.0))) | ||||
|  | ||||
| In order to have the change of state reflected in the outputs of ``fun``, you | ||||
| again have two options. The first option is to simply pass ``state`` as an | ||||
| input to the function. In some cases this can be pretty inconvenient. Hence, | ||||
| :func:`compile` also has a parameter to capture implicit inputs: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   from functools import partial | ||||
|   state = [mx.array(1.0)] | ||||
|  | ||||
|   # Tell compile to capture state as an input | ||||
|   @partial(mx.compile, inputs=state) | ||||
|   def fun(x): | ||||
|       return x + state[0] | ||||
|  | ||||
|   # Prints array(2, dtype=float32) | ||||
|   print(fun(mx.array(1.0))) | ||||
|  | ||||
|   # Update state | ||||
|   state[0] = mx.array(5.0) | ||||
|  | ||||
|   # Prints array(6, dtype=float32) | ||||
|   print(fun(mx.array(1.0))) | ||||
|  | ||||
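| Putting the two captures together, here is a minimal sketch of updating a | ||||
| container of arrays (a hypothetical parameter dictionary; ``inputs`` and | ||||
| ``outputs`` accept trees of arrays): | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   from functools import partial | ||||
|  | ||||
|   params = {"w": mx.array(0.0)} | ||||
|  | ||||
|   # Capture params as both input and output: reads inside the function | ||||
|   # see the latest values and the update to params["w"] is reflected | ||||
|   # after each call | ||||
|   @partial(mx.compile, inputs=params, outputs=params) | ||||
|   def update(x): | ||||
|       params["w"] = params["w"] + x | ||||
|       return params["w"] | ||||
|  | ||||
|   update(mx.array(1.0)) | ||||
|   update(mx.array(1.0)) | ||||
|   # Prints array(2, dtype=float32) | ||||
|   print(params["w"]) | ||||
|  | ||||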
|  | ||||
| Compiling Training Graphs  | ||||
| ------------------------- | ||||
|  | ||||
| This section will step through how to use :func:`compile` with a simple example | ||||
| of a common setup: training a model with :obj:`mlx.nn.Module` using an | ||||
| :obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the | ||||
| full forward, backward, and update with :func:`compile`. | ||||
|  | ||||
| To start, here is the simple example without any compilation: | ||||
|  | ||||
| .. code-block:: python  | ||||
|  | ||||
|   import mlx.core as mx | ||||
|   import mlx.nn as nn | ||||
|   import mlx.optimizers as optim | ||||
|  | ||||
|   # 4 examples with 10 features each | ||||
|   x = mx.random.uniform(shape=(4, 10)) | ||||
|  | ||||
|   # 0, 1 targets | ||||
|   y = mx.array([0, 1, 0, 1]) | ||||
|  | ||||
|   # Simple linear model | ||||
|   model = nn.Linear(10, 1) | ||||
|  | ||||
|   # SGD with momentum | ||||
|   optimizer = optim.SGD(learning_rate=0.1, momentum=0.8) | ||||
|  | ||||
|   def loss_fn(model, x, y): | ||||
|       logits = model(x).squeeze() | ||||
|       return nn.losses.binary_cross_entropy(logits, y) | ||||
|  | ||||
|   loss_and_grad_fn = nn.value_and_grad(model, loss_fn) | ||||
|  | ||||
|   # Perform 10 steps of gradient descent | ||||
|   for it in range(10): | ||||
|       loss, grads = loss_and_grad_fn(model, x, y) | ||||
|       optimizer.update(model, grads) | ||||
|       mx.eval(model.parameters(), optimizer.state) | ||||
|  | ||||
| To compile the update we can put it all in a function and compile it with the | ||||
| appropriate input and output captures. Here's the same example but compiled: | ||||
|  | ||||
| .. code-block:: python  | ||||
|  | ||||
|   import mlx.core as mx | ||||
|   import mlx.nn as nn | ||||
|   import mlx.optimizers as optim | ||||
|   from functools import partial | ||||
|  | ||||
|   # 4 examples with 10 features each | ||||
|   x = mx.random.uniform(shape=(4, 10)) | ||||
|  | ||||
|   # 0, 1 targets | ||||
|   y = mx.array([0, 1, 0, 1]) | ||||
|  | ||||
|   # Simple linear model | ||||
|   model = nn.Linear(10, 1) | ||||
|  | ||||
|   # SGD with momentum | ||||
|   optimizer = optim.SGD(learning_rate=0.1, momentum=0.8) | ||||
|  | ||||
|   def loss_fn(model, x, y): | ||||
|       logits = model(x).squeeze() | ||||
|       return nn.losses.binary_cross_entropy(logits, y) | ||||
|  | ||||
|   # The state that will be captured as input and output | ||||
|   state = [model.state, optimizer.state] | ||||
|  | ||||
|   @partial(mx.compile, inputs=state, outputs=state) | ||||
|   def step(x, y): | ||||
|       loss_and_grad_fn = nn.value_and_grad(model, loss_fn) | ||||
|       loss, grads = loss_and_grad_fn(model, x, y) | ||||
|       optimizer.update(model, grads) | ||||
|       return loss | ||||
|  | ||||
|   # Perform 10 steps of gradient descent | ||||
|   for it in range(10): | ||||
|       loss = step(x, y) | ||||
|       # Evaluate the model and optimizer state | ||||
|       mx.eval(state) | ||||
|       print(loss) | ||||
|  | ||||
|  | ||||
| .. note:: | ||||
|  | ||||
|   If you are using a module which performs random sampling such as | ||||
|   :class:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in | ||||
|   the ``state`` captured by :func:`compile`, i.e. ``state = [model.state, | ||||
|   optimizer.state, mx.random.state]``. | ||||
|  | ||||
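| For example, a sketch of such a ``state`` for a hypothetical model with | ||||
| dropout, reusing the setup from above: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   # A model which samples random dropout masks at each step | ||||
|   model = nn.Sequential(nn.Linear(10, 16), nn.Dropout(0.5), nn.Linear(16, 1)) | ||||
|   optimizer = optim.SGD(learning_rate=0.1, momentum=0.8) | ||||
|  | ||||
|   # Include mx.random.state so each compiled step draws new masks | ||||
|   state = [model.state, optimizer.state, mx.random.state] | ||||
|  | ||||
|   @partial(mx.compile, inputs=state, outputs=state) | ||||
|   def step(x, y): | ||||
|       loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y) | ||||
|       optimizer.update(model, grads) | ||||
|       return loss | ||||
|  | ||||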
|  | ||||
| .. note:: | ||||
|  | ||||
|    For more examples of compiling full training graphs, check out the `MLX | ||||
|    Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo. | ||||
|  | ||||
| Transformations with Compile | ||||
| ---------------------------- | ||||
|  | ||||
| In MLX, function transformations are composable. You can apply any function | ||||
| transformation to the output of any other function transformation. For more on | ||||
| this, see the documentation on :ref:`function transforms | ||||
| <function_transforms>`. | ||||
|  | ||||
| Compiling transformed functions works just as expected: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   grad_fn = mx.grad(mx.exp) | ||||
|  | ||||
|   compiled_grad_fn = mx.compile(grad_fn) | ||||
|  | ||||
|   # Prints: array(2.71828, dtype=float32) | ||||
|   print(grad_fn(mx.array(1.0))) | ||||
|  | ||||
|   # Also prints: array(2.71828, dtype=float32) | ||||
|   print(compiled_grad_fn(mx.array(1.0))) | ||||
|  | ||||
| .. note:: | ||||
|  | ||||
|    In order to compile as much as possible, a transformation of a compiled | ||||
|    function will not by default be compiled. To compile the transformed | ||||
|    function, simply pass it through :func:`compile`. | ||||
|  | ||||
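| For example, a small sketch of the pattern described in the note: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   compiled_exp = mx.compile(mx.exp) | ||||
|  | ||||
|   # The gradient of the compiled function is not itself compiled | ||||
|   grad_fn = mx.grad(compiled_exp) | ||||
|  | ||||
|   # Pass the transformed function through compile to compile it too | ||||
|   compiled_grad_fn = mx.compile(mx.grad(compiled_exp)) | ||||
|  | ||||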
| You can also compile functions which themselves call compiled functions. A | ||||
| good practice is to compile the outermost function to give :func:`compile` | ||||
| the most opportunity to optimize the computation graph: | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|   @mx.compile | ||||
|   def inner(x): | ||||
|       return mx.exp(-mx.abs(x)) | ||||
|  | ||||
|   def outer(x): | ||||
|       return inner(inner(x)) | ||||
|  | ||||
|   # Compiling the outer function is good to do as it will likely | ||||
|   # be faster even though the inner functions are compiled | ||||
|   fun = mx.compile(outer) | ||||
| @@ -5,9 +5,12 @@ Function Transforms | ||||
|  | ||||
| .. currentmodule:: mlx.core | ||||
|  | ||||
| MLX uses composable function transformations for automatic differentiation and | ||||
| vectorization. The key idea behind composable function transformations is that | ||||
| every transformation returns a function which can be further transformed.  | ||||
| MLX uses composable function transformations for automatic differentiation, | ||||
| vectorization, and compute graph optimizations. To see the complete list of | ||||
| function transformations, check out the :ref:`API documentation <transforms>`. | ||||
|  | ||||
| The key idea behind composable function transformations is that every | ||||
| transformation returns a function which can be further transformed. | ||||
|  | ||||
| Here is a simple example: | ||||
|  | ||||
| @@ -36,10 +39,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep | ||||
| getting higher order derivatives. | ||||
|  | ||||
| Any of the MLX function transformations can be composed in any order to any | ||||
| depth. To see the complete list of function transformations check-out the | ||||
| :ref:`API documentation <transforms>`. See the following sections for more | ||||
| information on :ref:`automatic differentiaion <auto diff>` and | ||||
| :ref:`automatic vectorization <vmap>`. | ||||
| depth. See the following sections for more information on :ref:`automatic | ||||
| differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`. | ||||
| For more information on :func:`compile`, see the :ref:`compile documentation <compile>`. | ||||
|  | ||||
|  | ||||
| Automatic Differentiation | ||||
| ------------------------- | ||||
|   | ||||