From 62eb28261c9409af16da2d2221a5183f0c2d187f Mon Sep 17 00:00:00 2001
From: CircleCI Docs
diff --git a/docs/build/html/dev/custom_metal_kernels.html b/docs/build/html/dev/custom_metal_kernels.html
index dbff63ffa..14931cf22 100644
--- a/docs/build/html/dev/custom_metal_kernels.html
+++ b/docs/build/html/dev/custom_metal_kernels.html
@@ -926,19 +926,20 @@ document.write(`
Simple Example#
Let's write a custom kernel that computes exp elementwise:
-def exp_elementwise(a: mx.array):
- source = """
- uint elem = thread_position_in_grid.x;
- T tmp = inp[elem];
- out[elem] = metal::exp(tmp);
- """
+
source = """
+ uint elem = thread_position_in_grid.x;
+ T tmp = inp[elem];
+ out[elem] = metal::exp(tmp);
+"""
- kernel = mx.fast.metal_kernel(
- name="myexp",
- input_names=["inp"],
- output_names=["out"],
- source=source,
- )
+kernel = mx.fast.metal_kernel(
+ name="myexp",
+ input_names=["inp"],
+ output_names=["out"],
+ source=source,
+)
+
+def exp_elementwise(a: mx.array):
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
@@ -954,9 +955,13 @@ document.write(`
assert mx.allclose(b, mx.exp(a))
+Every time you make a kernel, a new Metal library is created and possibly
+JIT compiled. To reduce the overhead from that, build the kernel once with
+fast.metal_kernel() and then use it many times.
Note
-We are only required to pass the body of the Metal kernel in source.
-The full function signature will be generated using:
+Only pass the body of the Metal kernel in source. The function
+signature is generated automatically.
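For illustration, a sketch of reusing the prebuilt kernel across many calls (the loop and shapes here are illustrative, not taken from the page):
import mlx.core as mx

# `kernel` and `exp_elementwise` are the objects defined above; the Metal
# library is built once and reused on every call.
xs = [mx.random.normal(shape=(1024,)) for _ in range(8)]
ys = [exp_elementwise(x) for x in xs]
for x, y in zip(xs, ys):
    assert mx.allclose(y, mx.exp(x))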
@@ -1004,37 +1009,43 @@ All the attributes defined in Table 5.8 of the Metal Shading Language Specification
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
-grid and threadgroup are parameters to the Metal dispatchThreads function.
-This means we will launch mx.prod(grid) threads, subdivided into threadgroup size threadgroups.
-For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
+grid and threadgroup are parameters to the Metal dispatchThreads
+function. This means we will launch mx.prod(grid) threads, subdivided into
+threadgroup size threadgroups. For optimal performance, each thread group
+dimension should be less than or equal to the corresponding grid dimension.
Note
-Passing verbose=True to mx.fast.metal_kernel.__call__ will print the generated code for debugging purposes.
+Passing verbose=True to fast.metal_kernel.__call__() will print the
+generated code for debugging purposes.
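As a concrete illustration of the launch geometry and the verbose flag (the sizes here are made up; the call mirrors the example above):
a = mx.random.normal(shape=(8, 32))  # 256 elements in total
outputs = kernel(
    inputs=[a],
    template=[("T", mx.float32)],
    grid=(a.size, 1, 1),       # mx.prod(grid) == 256 threads are launched
    threadgroup=(256, 1, 1),   # subdivided into threadgroups of size 256
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
    verbose=True,              # print the generated Metal source for debugging
)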
Using Shape/Strides#
-mx.fast.metal_kernel supports an argument ensure_row_contiguous which is True by default.
-This will copy the mx.array inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
-Generally this makes writing the kernel easier, since we don’t have to worry about gaps or the ordering of the dims
-when indexing.
-If we want to avoid this copy, metal_kernel automatically passes a_shape, a_strides and a_ndim for each
-input array a if any are present in source.
-We can then use MLX's built in indexing utils to fetch the right elements for each thread.
-Let's convert myexp above to support arbitrarily strided arrays without relying on a copy from ensure_row_contiguous:
-def exp_elementwise(a: mx.array):
- source = """
- uint elem = thread_position_in_grid.x;
- // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
- uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
- T tmp = inp[loc];
- // Output arrays are always row contiguous
- out[elem] = metal::exp(tmp);
- """
+
+fast.metal_kernel() supports an argument ensure_row_contiguous which
+is True by default. This will copy the array inputs if needed
+before the kernel is launched to ensure that the memory layout is row
+contiguous. Generally this makes writing the kernel easier, since we don't
+have to worry about gaps or the ordering of the dims when indexing.
+If we want to avoid this copy, fast.metal_kernel() automatically passes
+a_shape, a_strides and a_ndim for each input array a if any are
+present in source. We can then use MLX's built in indexing utils to fetch
+the right elements for each thread.
+Let's convert myexp above to support arbitrarily strided arrays without
+relying on a copy from ensure_row_contiguous:
+source = """
+ uint elem = thread_position_in_grid.x;
+ // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
+ uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
+ T tmp = inp[loc];
+ // Output arrays are always row contiguous
+ out[elem] = metal::exp(tmp);
+"""
- kernel = mx.fast.metal_kernel(
- name="myexp_strided",
- input_names=["inp"],
- output_names=["out"],
- source=source
- )
+kernel = mx.fast.metal_kernel(
+ name="myexp_strided",
+ input_names=["inp"],
+ output_names=["out"],
+ source=source
+)
+
+def exp_elementwise(a: mx.array):
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
@@ -1100,10 +1111,67 @@ We can then use MLX’s built in indexing utils to fetch the right elements for
return output
-Now let's use mx.custom_function together with mx.fast.metal_kernel
+Now let's use custom_function() together with fast.metal_kernel()
to write a fast GPU kernel for both the forward and backward passes.
-@mx.custom_function
+First we'll implement the forward pass as a fused kernel:
+
+source = """
+ uint elem = thread_position_in_grid.x;
+ int H = x_shape[1];
+ int W = x_shape[2];
+ int C = x_shape[3];
+ int gH = grid_shape[1];
+ int gW = grid_shape[2];
+
+ int w_stride = C;
+ int h_stride = W * w_stride;
+ int b_stride = H * h_stride;
+
+ uint grid_idx = elem / C * 2;
+ float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
+ float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
+
+ int ix_nw = floor(ix);
+ int iy_nw = floor(iy);
+
+ int ix_ne = ix_nw + 1;
+ int iy_ne = iy_nw;
+
+ int ix_sw = ix_nw;
+ int iy_sw = iy_nw + 1;
+
+ int ix_se = ix_nw + 1;
+ int iy_se = iy_nw + 1;
+
+ T nw = (ix_se - ix) * (iy_se - iy);
+ T ne = (ix - ix_sw) * (iy_sw - iy);
+ T sw = (ix_ne - ix) * (iy - iy_ne);
+ T se = (ix - ix_nw) * (iy - iy_nw);
+
+ int batch_idx = elem / C / gH / gW * b_stride;
+ int channel_idx = elem % C;
+ int base_idx = batch_idx + channel_idx;
+
+ T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
+ T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
+ T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
+ T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
+
+ I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
+ I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
+ I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
+ I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
+
+ out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
+"""
+
+kernel = mx.fast.metal_kernel(
+ name="grid_sample",
+ input_names=["x", "grid"],
+ output_names=["out"],
+ source=source,
+)
+
+@mx.custom_function
def grid_sample(x, grid):
assert x.ndim == 4, "`x` must be 4D."
@@ -1115,61 +1183,6 @@ to write a fast GPU kernel for both the forward and backward passes.
assert D == 2, "Last dim of `grid` must be size 2."
- source = """
- uint elem = thread_position_in_grid.x;
- int H = x_shape[1];
- int W = x_shape[2];
- int C = x_shape[3];
- int gH = grid_shape[1];
- int gW = grid_shape[2];
-
- int w_stride = C;
- int h_stride = W * w_stride;
- int b_stride = H * h_stride;
-
- uint grid_idx = elem / C * 2;
- float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
- float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
-
- int ix_nw = floor(ix);
- int iy_nw = floor(iy);
-
- int ix_ne = ix_nw + 1;
- int iy_ne = iy_nw;
-
- int ix_sw = ix_nw;
- int iy_sw = iy_nw + 1;
-
- int ix_se = ix_nw + 1;
- int iy_se = iy_nw + 1;
-
- T nw = (ix_se - ix) * (iy_se - iy);
- T ne = (ix - ix_sw) * (iy_sw - iy);
- T sw = (ix_ne - ix) * (iy - iy_ne);
- T se = (ix - ix_nw) * (iy - iy_nw);
-
- int batch_idx = elem / C / gH / gW * b_stride;
- int channel_idx = elem % C;
- int base_idx = batch_idx + channel_idx;
-
- T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
- T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
- T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
- T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
-
- I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
- I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
- I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
- I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
-
- out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
- """
- kernel = mx.fast.metal_kernel(
- name="grid_sample",
- input_names=["x", "grid"],
- output_names=["out"],
- source=source,
- )
outputs = kernel(
inputs=[x, grid],
template=[("T", x.dtype)],
@@ -1191,10 +1204,10 @@ to write a fast GPU kernel for both the forward and backward passes.
Grid Sample VJP#
-Since we decorated grid_sample with mx.custom_function, we can now define
-its custom vjp transform so MLX can differentiate it.
+Since we decorated grid_sample with custom_function(), we can now
+define its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating x_grad/grid_grad and so
-requires a few extra mx.fast.metal_kernel features:
+requires a few extra fast.metal_kernel() features:
init_value=0
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
-@grid_sample.vjp
+
source = """
+ uint elem = thread_position_in_grid.x;
+ int H = x_shape[1];
+ int W = x_shape[2];
+ int C = x_shape[3];
+ // Pad C to the nearest larger simdgroup size multiple
+ int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
+
+ int gH = grid_shape[1];
+ int gW = grid_shape[2];
+
+ int w_stride = C;
+ int h_stride = W * w_stride;
+ int b_stride = H * h_stride;
+
+ uint grid_idx = elem / C_padded * 2;
+ float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
+ float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
+
+ int ix_nw = floor(ix);
+ int iy_nw = floor(iy);
+
+ int ix_ne = ix_nw + 1;
+ int iy_ne = iy_nw;
+
+ int ix_sw = ix_nw;
+ int iy_sw = iy_nw + 1;
+
+ int ix_se = ix_nw + 1;
+ int iy_se = iy_nw + 1;
+
+ T nw = (ix_se - ix) * (iy_se - iy);
+ T ne = (ix - ix_sw) * (iy_sw - iy);
+ T sw = (ix_ne - ix) * (iy - iy_ne);
+ T se = (ix - ix_nw) * (iy - iy_nw);
+
+ int batch_idx = elem / C_padded / gH / gW * b_stride;
+ int channel_idx = elem % C_padded;
+ int base_idx = batch_idx + channel_idx;
+
+ T gix = T(0);
+ T giy = T(0);
+ if (channel_idx < C) {
+ int cot_index = elem / C_padded * C + channel_idx;
+ T cot = cotangent[cot_index];
+ if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
+ int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
+ atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
+
+ T I_nw = x[offset];
+ gix -= I_nw * (iy_se - iy) * cot;
+ giy -= I_nw * (ix_se - ix) * cot;
+ }
+ if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
+ int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
+ atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
+
+ T I_ne = x[offset];
+ gix += I_ne * (iy_sw - iy) * cot;
+ giy -= I_ne * (ix - ix_sw) * cot;
+ }
+ if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
+ int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
+ atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
+
+ T I_sw = x[offset];
+ gix -= I_sw * (iy - iy_ne) * cot;
+ giy += I_sw * (ix_ne - ix) * cot;
+ }
+ if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
+ int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
+ atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
+
+ T I_se = x[offset];
+ gix += I_se * (iy - iy_nw) * cot;
+ giy += I_se * (ix - ix_nw) * cot;
+ }
+ }
+
+ T gix_mult = W / 2;
+ T giy_mult = H / 2;
+
+ // Reduce across each simdgroup first.
+ // This is much faster than relying purely on atomics.
+ gix = simd_sum(gix);
+ giy = simd_sum(giy);
+
+ if (thread_index_in_simdgroup == 0) {
+ atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
+ atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
+ }
+"""
+kernel = mx.fast.metal_kernel(
+ name="grid_sample_grad",
+ input_names=["x", "grid", "cotangent"],
+ output_names=["x_grad", "grid_grad"],
+ source=source,
+ atomic_outputs=True,
+)
+
+@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
@@ -1218,105 +1331,6 @@ See section 6.15 of the Metal Shading Language Specification
assert D == 2, "Last dim of `grid` must be size 2."
- source = """
- uint elem = thread_position_in_grid.x;
- int H = x_shape[1];
- int W = x_shape[2];
- int C = x_shape[3];
- // Pad C to the nearest larger simdgroup size multiple
- int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
-
- int gH = grid_shape[1];
- int gW = grid_shape[2];
-
- int w_stride = C;
- int h_stride = W * w_stride;
- int b_stride = H * h_stride;
-
- uint grid_idx = elem / C_padded * 2;
- float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
- float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
-
- int ix_nw = floor(ix);
- int iy_nw = floor(iy);
-
- int ix_ne = ix_nw + 1;
- int iy_ne = iy_nw;
-
- int ix_sw = ix_nw;
- int iy_sw = iy_nw + 1;
-
- int ix_se = ix_nw + 1;
- int iy_se = iy_nw + 1;
-
- T nw = (ix_se - ix) * (iy_se - iy);
- T ne = (ix - ix_sw) * (iy_sw - iy);
- T sw = (ix_ne - ix) * (iy - iy_ne);
- T se = (ix - ix_nw) * (iy - iy_nw);
-
- int batch_idx = elem / C_padded / gH / gW * b_stride;
- int channel_idx = elem % C_padded;
- int base_idx = batch_idx + channel_idx;
-
- T gix = T(0);
- T giy = T(0);
- if (channel_idx < C) {
- int cot_index = elem / C_padded * C + channel_idx;
- T cot = cotangent[cot_index];
- if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
- int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
- atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
-
- T I_nw = x[offset];
- gix -= I_nw * (iy_se - iy) * cot;
- giy -= I_nw * (ix_se - ix) * cot;
- }
- if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
- int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
- atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
-
- T I_ne = x[offset];
- gix += I_ne * (iy_sw - iy) * cot;
- giy -= I_ne * (ix - ix_sw) * cot;
- }
- if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
- int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
- atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
-
- T I_sw = x[offset];
- gix -= I_sw * (iy - iy_ne) * cot;
- giy += I_sw * (ix_ne - ix) * cot;
- }
- if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
- int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
- atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
-
- T I_se = x[offset];
- gix += I_se * (iy - iy_nw) * cot;
- giy += I_se * (ix - ix_nw) * cot;
- }
- }
-
- T gix_mult = W / 2;
- T giy_mult = H / 2;
-
- // Reduce across each simdgroup first.
- // This is much faster than relying purely on atomics.
- gix = simd_sum(gix);
- giy = simd_sum(giy);
-
- if (thread_index_in_simdgroup == 0) {
- atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
- atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
- }
- """
- kernel = mx.fast.metal_kernel(
- name="grid_sample_grad",
- input_names=["x", "grid", "cotangent"],
- output_names=["x_grad", "grid_grad"],
- source=source,
- atomic_outputs=True,
- )
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
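To make the padding and atomic-output pieces concrete, here is a sketch of how the backward kernel might be invoked; the exact call is not shown in this hunk, so the argument names follow the forward example and gH/gW are assumed to come from grid.shape:
gH, gW = grid.shape[1], grid.shape[2]
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
outputs = kernel(
    inputs=[x, grid, cotangent],
    template=[("T", x.dtype)],
    output_shapes=[x.shape, grid.shape],
    output_dtypes=[x.dtype, grid.dtype],
    grid=(B * gH * gW * C_padded, 1, 1),  # one thread per padded channel per grid point
    threadgroup=(256, 1, 1),
    init_value=0,  # zero both outputs first so partial atomic updates are correct
)
x_grad, grid_grad = outputs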
diff --git a/docs/build/html/dev/extensions.html b/docs/build/html/dev/extensions.html
index 2826dc74a..eb5044b28 100644
--- a/docs/build/html/dev/extensions.html
+++ b/docs/build/html/dev/extensions.html
diff --git a/docs/build/html/install.html b/docs/build/html/install.html
--- a/docs/build/html/install.html
+++ b/docs/build/html/install.html
conda install conda-forge::mlx
+MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
+and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
+pip install mlx-cuda
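After installing, a quick sanity check (a generic MLX snippet, not from the page) is to confirm the default device and run a small computation:
import mlx.core as mx

print(mx.default_device())  # should report the GPU when a GPU backend is active
a = mx.ones((256, 256))
mx.eval(a @ a)              # force evaluation on the default device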
+My OS and Python versions are in the required range but pip still does not find
@@ -988,7 +999,7 @@ the output of uname
To build and install the MLX python library from source, first, clone MLX from its GitHub repo:
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Currently, MLX must be built and installed from source.
Similarly to the python library, to build and install the MLX C++ library start by cloning MLX from its GitHub repo:
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
@@ -1041,8 +1052,8 @@ cmake .. &&
directory as the executable statically linked to libmlx.a or the
preprocessor constant METAL_PATH should be defined at build time and it
should point to the path to the built metal library.
-
-Build Options#
+
+Build Options#
@@ -1125,8 +1136,40 @@ application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.
+
+Linux#
+To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
+For example on Ubuntu, run the following:
+apt-get update -y
+apt-get install libblas-dev liblapack-dev liblapacke-dev -y
+
+
+From here follow the instructions to install either the Python or C++ APIs.
+
-Troubleshooting#
+CUDA#
+To build from source on Linux with CUDA, install the BLAS and LAPACK headers
+and the CUDA toolkit. For example on Ubuntu, run the following:
+wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
+dpkg -i cuda-keyring_1.1-1_all.deb
+apt-get update -y
+apt-get -y install cuda-toolkit-12-9
+apt-get install libblas-dev liblapack-dev liblapacke-dev -y
+
+
+When building either the Python or C++ APIs make sure to pass the cmake flag
+MLX_BUILD_CUDA=ON. For example, to build the Python API run:
+CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
+
+
+To build the C++ package run:
+mkdir -p build && cd build
+cmake .. -DMLX_BUILD_CUDA=ON && make -j
+
+
+
+
+Troubleshooting#
Metal not found#
You see the following error when you try to build:
@@ -1214,6 +1257,7 @@ wipe your build cache with
- Python Installation
@@ -1224,7 +1268,9 @@ wipe your build cache with
Binary Size Minimization
-Troubleshooting