Fix unintuitive metal kernel caching (#2242)

* Fix unintuitive metal kernel caching

* alternative solution
Awni Hannun 2025-06-06 20:08:15 -07:00 committed by GitHub
parent 2e8cf0b450
commit 1ca616844b
13 changed files with 713 additions and 593 deletions
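The scenario addressed here is the one exercised by the regression test added in this commit: two `mx.fast.metal_kernel` objects that share a `name` but have different `source` strings must not silently resolve to the same cached Metal library. A minimal sketch of that pattern (the kernel bodies and sizes are only illustrative):

import mlx.core as mx

def call_kernel(a, source):
    # Both kernels use the same name on purpose; only the body differs.
    kernel = mx.fast.metal_kernel(
        name="my_kernel",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )
    return kernel(
        inputs=[a],
        grid=(a.size, 1, 1),
        threadgroup=(a.size, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        stream=mx.gpu,
    )[0]

a = mx.random.normal(shape=(32,))
zeros = call_kernel(a, "uint elem = thread_position_in_grid.x;\nout[elem] = 0.0;")
ones = call_kernel(a, "uint elem = thread_position_in_grid.x;\nout[elem] = 1.0;")
# With this fix the second call rebuilds the library instead of reusing the
# one cached for the first source.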

View File

@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example
--------------

.. currentmodule:: mlx.core

Let's write a custom kernel that computes ``exp`` elementwise:

.. code-block:: python

  source = """
      uint elem = thread_position_in_grid.x;
      T tmp = inp[elem];
      out[elem] = metal::exp(tmp);
  """

  kernel = mx.fast.metal_kernel(
      name="myexp",
      input_names=["inp"],
      output_names=["out"],
      source=source,
  )

  def exp_elementwise(a: mx.array):
      outputs = kernel(
          inputs=[a],
          template=[("T", mx.float32)],
@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
  b = exp_elementwise(a)
  assert mx.allclose(b, mx.exp(a))
Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.
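For instance, the objects defined above can be reused across many calls; a minimal sketch (the input sizes are only illustrative):

.. code-block:: python

  # `kernel` and `exp_elementwise` are defined once above. The Metal library
  # is built on the first call and reused for the later ones.
  xs = [mx.random.normal(shape=(n,)) for n in (128, 1024, 4096)]
  ys = [exp_elementwise(x) for x in xs]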
.. note::
  Only pass the body of the Metal kernel in ``source``. The function
  signature is generated automatically.

The full function signature will be generated using:
@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as
  template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.
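As a concrete illustration of that relationship (the sizes are only an example), the call below launches 4096 threads split into 16 threadgroups of 256:

.. code-block:: python

  a = mx.zeros((4096,))
  outputs = kernel(
      inputs=[a],
      template=[("T", mx.float32)],
      grid=(a.size, 1, 1),       # mx.prod(grid) == 4096 threads in total
      threadgroup=(256, 1, 1),   # 4096 / 256 == 16 threadgroups
      output_shapes=[a.shape],
      output_dtypes=[a.dtype],
  )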
Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the
generated code for debugging purposes.
Using Shape/Strides
-------------------

:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
is ``True`` by default. This will copy the array inputs if needed
before the kernel is launched to ensure that the memory layout is row
contiguous. Generally this makes writing the kernel easier, since we don't
have to worry about gaps or the ordering of the dims when indexing.

If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
present in ``source``. We can then use MLX's built in indexing utils to fetch
the right elements for each thread.

Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python

  source = """
      uint elem = thread_position_in_grid.x;
      // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
      uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
      T tmp = inp[loc];
      // Output arrays are always row contiguous
      out[elem] = metal::exp(tmp);
  """

  kernel = mx.fast.metal_kernel(
      name="myexp_strided",
      input_names=["inp"],
      output_names=["out"],
      source=source
  )

  def exp_elementwise(a: mx.array):
      outputs = kernel(
          inputs=[a],
          template=[("T", mx.float32)],
@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops:
.. code-block:: python

  def grid_sample_ref(x, grid):
      N, H_in, W_in, _ = x.shape
      ix = ((grid[..., 0] + 1) * W_in - 1) / 2
      iy = ((grid[..., 1] + 1) * H_in - 1) / 2

      ix_nw = mx.floor(ix).astype(mx.int32)
      iy_nw = mx.floor(iy).astype(mx.int32)

      ix_ne = ix_nw + 1
      iy_ne = iy_nw

      ix_sw = ix_nw
      iy_sw = iy_nw + 1

      ix_se = ix_nw + 1
      iy_se = iy_nw + 1

      nw = (ix_se - ix) * (iy_se - iy)
      ne = (ix - ix_sw) * (iy_sw - iy)
      sw = (ix_ne - ix) * (iy - iy_ne)
      se = (ix - ix_nw) * (iy - iy_nw)

      I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
      I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
      I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
      I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

      mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
      mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
      mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
      mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

      I_nw *= mask_nw[..., None]
      I_ne *= mask_ne[..., None]
      I_sw *= mask_sw[..., None]
      I_se *= mask_se[..., None]

      output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

      return output

Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
to write a fast GPU kernel for both the forward and backward passes.

First we'll implement the forward pass as a fused kernel:
.. code-block:: python

  source = """
      uint elem = thread_position_in_grid.x;
      int H = x_shape[1];
      int W = x_shape[2];
      int C = x_shape[3];
      int gH = grid_shape[1];
      int gW = grid_shape[2];

      int w_stride = C;
      int h_stride = W * w_stride;
      int b_stride = H * h_stride;

      uint grid_idx = elem / C * 2;
      float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
      float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

      int ix_nw = floor(ix);
      int iy_nw = floor(iy);

      int ix_ne = ix_nw + 1;
      int iy_ne = iy_nw;

      int ix_sw = ix_nw;
      int iy_sw = iy_nw + 1;

      int ix_se = ix_nw + 1;
      int iy_se = iy_nw + 1;

      T nw = (ix_se - ix) * (iy_se - iy);
      T ne = (ix - ix_sw) * (iy_sw - iy);
      T sw = (ix_ne - ix) * (iy - iy_ne);
      T se = (ix - ix_nw) * (iy - iy_nw);

      int batch_idx = elem / C / gH / gW * b_stride;
      int channel_idx = elem % C;
      int base_idx = batch_idx + channel_idx;

      T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
      T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
      T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
      T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

      I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
      I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
      I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
      I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

      out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
  """

  kernel = mx.fast.metal_kernel(
      name="grid_sample",
      input_names=["x", "grid"],
      output_names=["out"],
      source=source,
  )

  @mx.custom_function
  def grid_sample(x, grid):

      assert x.ndim == 4, "`x` must be 4D."
      assert grid.ndim == 4, "`grid` must be 4D."

      B, _, _, C = x.shape
      _, gN, gM, D = grid.shape
      out_shape = (B, gN, gM, C)

      assert D == 2, "Last dim of `grid` must be size 2."

      outputs = kernel(
          inputs=[x, grid],
          template=[("T", x.dtype)],
          output_shapes=[out_shape],
          output_dtypes=[x.dtype],
          grid=(np.prod(out_shape), 1, 1),
          threadgroup=(256, 1, 1),
      )
      return outputs[0]
For a reasonably sized input such as:

.. code-block:: python

  x.shape = (8, 1024, 1024, 64)
  grid.shape = (8, 256, 256, 2)

On an M1 Max, we see a big performance improvement:
@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
Grid Sample VJP
---------------

Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
define its custom vjp transform so MLX can differentiate it.

The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra :func:`fast.metal_kernel` features:

* ``init_value=0``
  Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
@ -299,128 +316,129 @@ We can then implement the backwards pass as follows:
.. code-block:: python

  source = """
      uint elem = thread_position_in_grid.x;
      int H = x_shape[1];
      int W = x_shape[2];
      int C = x_shape[3];
      // Pad C to the nearest larger simdgroup size multiple
      int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

      int gH = grid_shape[1];
      int gW = grid_shape[2];

      int w_stride = C;
      int h_stride = W * w_stride;
      int b_stride = H * h_stride;

      uint grid_idx = elem / C_padded * 2;
      float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
      float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

      int ix_nw = floor(ix);
      int iy_nw = floor(iy);

      int ix_ne = ix_nw + 1;
      int iy_ne = iy_nw;

      int ix_sw = ix_nw;
      int iy_sw = iy_nw + 1;

      int ix_se = ix_nw + 1;
      int iy_se = iy_nw + 1;

      T nw = (ix_se - ix) * (iy_se - iy);
      T ne = (ix - ix_sw) * (iy_sw - iy);
      T sw = (ix_ne - ix) * (iy - iy_ne);
      T se = (ix - ix_nw) * (iy - iy_nw);

      int batch_idx = elem / C_padded / gH / gW * b_stride;
      int channel_idx = elem % C_padded;
      int base_idx = batch_idx + channel_idx;

      T gix = T(0);
      T giy = T(0);
      if (channel_idx < C) {
          int cot_index = elem / C_padded * C + channel_idx;
          T cot = cotangent[cot_index];
          if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
              int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
              atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

              T I_nw = x[offset];
              gix -= I_nw * (iy_se - iy) * cot;
              giy -= I_nw * (ix_se - ix) * cot;
          }
          if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
              int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
              atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

              T I_ne = x[offset];
              gix += I_ne * (iy_sw - iy) * cot;
              giy -= I_ne * (ix - ix_sw) * cot;
          }
          if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
              int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
              atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

              T I_sw = x[offset];
              gix -= I_sw * (iy - iy_ne) * cot;
              giy += I_sw * (ix_ne - ix) * cot;
          }
          if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
              int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
              atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

              T I_se = x[offset];
              gix += I_se * (iy - iy_nw) * cot;
              giy += I_se * (ix - ix_nw) * cot;
          }
      }

      T gix_mult = W / 2;
      T giy_mult = H / 2;

      // Reduce across each simdgroup first.
      // This is much faster than relying purely on atomics.
      gix = simd_sum(gix);
      giy = simd_sum(giy);

      if (thread_index_in_simdgroup == 0) {
          atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
          atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
      }
  """

  kernel = mx.fast.metal_kernel(
      name="grid_sample_grad",
      input_names=["x", "grid", "cotangent"],
      output_names=["x_grad", "grid_grad"],
      source=source,
      atomic_outputs=True,
  )

  @grid_sample.vjp
  def grid_sample_vjp(primals, cotangent, _):
      x, grid = primals
      B, _, _, C = x.shape
      _, gN, gM, D = grid.shape

      assert D == 2, "Last dim of `grid` must be size 2."

      # pad the output channels to simd group size
      # so that our `simd_sum`s don't overlap.
      simdgroup_size = 32
      C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
      grid_size = B * gN * gM * C_padded
      outputs = kernel(
          inputs=[x, grid, cotangent],
          template=[("T", x.dtype)],
          output_shapes=[x.shape, grid.shape],
          output_dtypes=[x.dtype, x.dtype],
          grid=(grid_size, 1, 1),
          threadgroup=(256, 1, 1),
          init_value=0,
      )
      return outputs[0], outputs[1]
There's an even larger speed up for the vjp:

View File

@ -397,11 +397,11 @@ below.
  std::ostringstream kname;
  kname << "axpby_" << "general_" << type_to_name(out);

-   // Make sure the metal library is available
-   d.register_library("mlx_ext");
  // Load the metal library
  auto lib = d.get_library("mlx_ext");

  // Make a kernel from this metal library
-   auto kernel = d.get_kernel(kname.str(), "mlx_ext");
  auto kernel = d.get_kernel(kname.str(), lib);

  // Prepare to encode kernel
  auto& compute_encoder = d.get_command_encoder(s.index);

View File

@ -172,11 +172,11 @@ void Axpby::eval_gpu(
  kname << (contiguous_kernel ? "contiguous_" : "general_");
  kname << type_to_name(out);

-   // Make sure the metal library is available
-   d.register_library("mlx_ext");
  // Load the metal library
  auto lib = d.get_library("mlx_ext");

  // Make a kernel from this metal library
-   auto kernel = d.get_kernel(kname.str(), "mlx_ext");
  auto kernel = d.get_kernel(kname.str(), lib);

  // Prepare to encode kernel
  auto& compute_encoder = d.get_command_encoder(s.index);

View File

@ -677,7 +677,7 @@ void depthwise_conv_2D_gpu(
  std::string hash_name = kname.str();

  auto& compute_encoder = d.get_command_encoder(s.index);
-   auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
  auto kernel = d.get_kernel(base_name, hash_name, func_consts);
  compute_encoder.set_compute_pipeline_state(kernel);
  compute_encoder.set_input_array(in, 0);

View File

@ -1,12 +1,326 @@
// Copyright © 2024 Apple Inc.

#include <iostream>
#include <regex>

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/utils.h"

namespace mlx::core::fast {
struct CustomKernelCache {
std::unordered_map<std::string, std::string> libraries;
};
static CustomKernelCache& cache() {
static CustomKernelCache cache_;
return cache_;
};
std::string write_signature(
std::string func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<std::string>& attributes,
const std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 16384);
kernel_source += header;
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
if (!template_args.empty()) {
kernel_source += "template <";
int i = 0;
for (const auto& [name, arg] : template_args) {
std::string param_type;
if (std::holds_alternative<int>(arg)) {
param_type = "int";
} else if (std::holds_alternative<bool>(arg)) {
param_type = "bool";
} else if (std::holds_alternative<Dtype>(arg)) {
param_type = "typename";
}
if (i > 0) {
kernel_source += ", ";
}
kernel_source += param_type;
kernel_source += " ";
kernel_source += name;
i++;
}
kernel_source += ">\n";
}
kernel_source += "[[kernel]] void ";
kernel_source += func_name;
kernel_source += "(\n";
int index = 0;
constexpr int max_constant_array_size = 8;
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
auto dtype = get_type_string(arr.dtype());
std::string location =
arr.size() < max_constant_array_size ? "constant" : "device";
std::string ref = arr.ndim() == 0 ? "&" : "*";
kernel_source += " const ";
kernel_source += location;
kernel_source += " ";
kernel_source += dtype;
kernel_source += ref;
kernel_source += " ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]],\n";
index++;
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source +=
(" const constant int* " + name + "_shape [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].strides) {
kernel_source +=
(" const constant int64_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].ndim) {
kernel_source +=
(" const constant int& " + name + "_ndim [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " device ";
auto type_string = get_type_string(dtype);
if (atomic_outputs) {
kernel_source += "atomic<";
}
kernel_source += type_string;
if (atomic_outputs) {
kernel_source += ">";
}
kernel_source += "* ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]]";
if (index < inputs.size() + output_names.size() - 1 ||
attributes.size() > 0) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
index = 0;
for (const auto& attr : attributes) {
kernel_source += attr;
if (index < attributes.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
kernel_source += source;
kernel_source += "\n}\n";
return kernel_source;
}
std::string write_template(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
std::ostringstream template_def;
template_def << "<";
int i = 0;
for (const auto& [name, arg] : template_args) {
if (i > 0) {
template_def << ", ";
}
if (std::holds_alternative<int>(arg)) {
template_def << std::get<int>(arg);
} else if (std::holds_alternative<bool>(arg)) {
template_def << std::get<bool>(arg);
} else if (std::holds_alternative<Dtype>(arg)) {
template_def << get_type_string(std::get<Dtype>(arg));
}
i++;
}
template_def << ">";
return template_def.str();
}
MetalKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header /* = "" */,
bool ensure_row_contiguous /* = true */,
bool atomic_outputs /* = false */) {
if (output_names.empty()) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
{"dispatch_quadgroups_per_threadgroup", "uint"},
{"dispatch_simdgroups_per_threadgroup", "uint"},
{"dispatch_threads_per_threadgroup", "uint3"},
{"grid_origin", "uint3"},
{"grid_size", "uint3"},
{"quadgroup_index_in_threadgroup", "uint"},
{"quadgroups_per_threadgroup", "uint"},
{"simdgroup_index_in_threadgroup", "uint"},
{"simdgroups_per_threadgroup", "uint"},
{"thread_execution_width", "uint"},
{"thread_index_in_quadgroup", "uint"},
{"thread_index_in_simdgroup", "uint"},
{"thread_index_in_threadgroup", "uint"},
{"thread_position_in_grid", "uint3"},
{"thread_position_in_threadgroup", "uint3"},
{"threadgroup_position_in_grid", "uint3"},
{"threadgroups_per_grid", "uint3"},
{"threads_per_grid", "uint3"},
{"threads_per_simdgroup", "uint"},
{"threads_per_threadgroup", "uint3"},
};
std::vector<std::string> attributes;
for (const auto& [attr, dtype] : metal_attributes) {
if (source.find(attr) != std::string::npos) {
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
}
}
return [=,
shape_infos = std::move(shape_infos),
attributes = std::move(attributes)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
}
std::string kernel_name = "custom_kernel_" + name;
std::string template_def = "";
if (!template_args.empty()) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args);
auto template_hash =
std::regex_replace(template_def, disallowed_chars, "_");
template_hash.pop_back();
kernel_name += "_";
kernel_name += template_hash;
}
std::string kernel_source = write_signature(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
attributes,
shape_infos,
atomic_outputs);
if (!template_args.empty()) {
template_def = kernel_name + template_def;
kernel_source += "\ntemplate [[host_name(\"";
kernel_source += kernel_name;
kernel_source += "\")]] [[kernel]] decltype(";
kernel_source += template_def;
kernel_source += ") ";
kernel_source += template_def;
kernel_source += ";\n";
}
if (verbose) {
std::cout << "Generated source code for `" << name << "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value),
std::move(inputs));
};
}
void CustomKernel::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
@ -39,9 +353,23 @@ void CustomKernel::eval_gpu(
  }

  auto& d = metal::device(s.device);
-   const auto& lib_name = name_;
-   auto lib =
-       d.get_library(lib_name, [this] { return metal::utils() + source_; });
  {
    // Clear kernels from the device library cache if needed
    auto& kernel_cache = cache();
    if (auto it = kernel_cache.libraries.find(name_);
        it != kernel_cache.libraries.end()) {
      if (it->second != source_) {
        auto& d = metal::device(s.device);
        d.clear_library(name_);
        it->second = source_;
      }
    } else {
      kernel_cache.libraries.emplace(name_, source_);
    }
  }

  auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
  auto kernel = d.get_kernel(name_, lib);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);
@ -73,6 +401,16 @@ void CustomKernel::eval_gpu(
  }

  const auto [tx, ty, tz] = threadgroup_;
auto tg_size = tx * ty * tz;
auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
if (tg_size > max_tg_size) {
std::ostringstream msg;
msg << "Thread group size (" << tg_size << ") is greater than "
<< " the maximum allowed threads per threadgroup (" << max_tg_size
<< ").";
throw std::invalid_argument(msg.str());
}
  const auto [gx, gy, gz] = grid_;
  MTL::Size group_dims =
      MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));

View File

@ -295,7 +295,7 @@ void CommandEncoder::barrier() {
Device::Device() {
  auto pool = new_scoped_memory_pool();
  device_ = load_device();
-   library_map_ = {{"mlx", load_default_library(device_)}};
  default_library_ = load_default_library(device_);
  arch_ = std::string(device_->architecture()->name()->utf8String());
  auto arch = arch_.back();
  switch (arch) {
@ -326,11 +326,11 @@ Device::Device() {
Device::~Device() {
  auto pool = new_scoped_memory_pool();
-   for (auto& k : kernel_map_) {
-     k.second->release();
-   }
-   for (auto& l : library_map_) {
-     l.second->release();
-   }
  for (auto& [l, kernel_map] : library_kernels_) {
    l->release();
    for (auto& [_, k] : kernel_map) {
      k->release();
    }
  }
  stream_map_.clear();
  device_->release();
@ -474,13 +474,24 @@ CommandEncoder& Device::get_command_encoder(int index) {
  return *stream.encoder;
}

-   void Device::register_library(
-       const std::string& lib_name,
-       const std::string& lib_path) {
-     if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
-       auto new_lib = load_library(device_, lib_name, lib_path.c_str());
-       library_map_.insert({lib_name, new_lib});
-     }
-   }
MTL::Library* Device::get_library(
    const std::string& name,
    const std::string& path /* = "" */) {
  {
    std::shared_lock rlock(library_mtx_);
    if (auto it = library_map_.find(name); it != library_map_.end()) {
      return it->second;
    }
  }

  std::unique_lock wlock(library_mtx_);
  if (auto it = library_map_.find(name); it != library_map_.end()) {
    return it->second;
  }
  auto new_lib = load_library(device_, name, path.c_str());
  library_map_.insert({name, new_lib});
  return new_lib;
}

MTL::Library* Device::build_library_(const std::string& source_string) {
@ -649,6 +660,19 @@ MTL::Library* Device::get_library(
  return mtl_lib;
}
void Device::clear_library(const std::string& name) {
std::unique_lock wlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
auto kernel_map_it = library_kernels_.find(it->second);
for (auto& [_, kernel] : kernel_map_it->second) {
kernel->release();
}
library_kernels_.erase(kernel_map_it);
it->second->release();
library_map_.erase(it);
}
}
MTL::LinkedFunctions* Device::get_linked_functions_(
    const std::vector<MTL::Function*>& funcs) {
  if (funcs.empty()) {
@ -679,6 +703,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
  std::unique_lock wlock(kernel_mtx_);

  // Try loading again to avoid loading twice
  auto& kernel_map_ = library_kernels_[mtl_lib];
  if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
    return it->second;
  }
@ -713,6 +738,7 @@ MTL::ComputePipelineState* Device::get_kernel(
  std::shared_lock lock(kernel_mtx_);

  // Look for cached kernel
  auto& kernel_map_ = library_kernels_[mtl_lib];
  if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
    return it->second;
  }
@ -722,23 +748,11 @@ MTL::ComputePipelineState* Device::get_kernel(
MTL::ComputePipelineState* Device::get_kernel(
    const std::string& base_name,
-     const std::string& lib_name /* = "mlx" */,
    const std::string& hash_name /* = "" */,
    const MTLFCList& func_consts /* = {} */,
    const std::vector<MTL::Function*>& linked_functions /* = {} */) {
-   const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
-   {
-     // Multiple readers allowed
-     std::shared_lock lock(kernel_mtx_);
-     // Look for cached kernel
-     if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
-       return it->second;
-     }
-   }
-   // Search for cached metal lib
-   MTL::Library* mtl_lib = get_library_(lib_name);
-   return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
  return get_kernel(
      base_name, default_library_, hash_name, func_consts, linked_functions);
}

void Device::set_residency_set(const MTL::ResidencySet* residency_set) {

View File

@ -187,14 +187,16 @@ class Device {
  CommandEncoder& get_command_encoder(int index);
  void end_encoding(int index);

-   void register_library(
-       const std::string& lib_name,
-       const std::string& lib_path = "");
  MTL::Library* get_library(
      const std::string& name,
      const std::string& path = "");

  MTL::Library* get_library(
      const std::string& name,
      const std::function<std::string(void)>& builder);

  void clear_library(const std::string& name);

  MTL::ComputePipelineState* get_kernel(
      const std::string& base_name,
      MTL::Library* mtl_lib,
@ -204,7 +206,6 @@ class Device {
  MTL::ComputePipelineState* get_kernel(
      const std::string& base_name,
-       const std::string& lib_name = "mlx",
      const std::string& hash_name = "",
      const MTLFCList& func_consts = {},
      const std::vector<MTL::Function*>& linked_functions = {});
@ -258,10 +259,13 @@ class Device {
  std::unordered_map<int32_t, DeviceStream> stream_map_;

  std::shared_mutex kernel_mtx_;
-   std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;

  std::shared_mutex library_mtx_;
  std::unordered_map<std::string, MTL::Library*> library_map_;
  MTL::Library* default_library_;
  std::unordered_map<
      MTL::Library*,
      std::unordered_map<std::string, MTL::ComputePipelineState*>>
      library_kernels_;

  const MTL::ResidencySet* residency_set_{nullptr};
  std::string arch_;
  int max_ops_per_buffer_;

View File

@ -146,7 +146,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
    int,
    int,
    int) {
-   return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
@ -207,7 +207,7 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
    int,
    int,
    bool) {
-   return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_gemv_masked_kernel(
@ -259,7 +259,7 @@ MTL::ComputePipelineState* get_fft_kernel(
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const std::string&) {
-   return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

MTL::ComputePipelineState* get_quantized_kernel(
@ -283,7 +283,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
    int,
    int,
    bool) {
-   return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
  return d.get_kernel(kernel_name, hash_name, func_consts);
}

} // namespace mlx::core

View File

@ -172,7 +172,7 @@ void RMSNormVJP::eval_gpu(
  auto& compute_encoder = d.get_command_encoder(s.index);
  {
-     auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
    auto kernel = d.get_kernel(op_name, hash_name, func_consts);

    MTL::Size grid_dims, group_dims;
    if (axis_size <= looped_limit) {
@ -395,7 +395,7 @@ void LayerNormVJP::eval_gpu(
  };

  {
-     auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
    auto kernel = d.get_kernel(op_name, hash_name, func_consts);

    MTL::Size grid_dims, group_dims;
    if (axis_size <= looped_limit) {

View File

@ -73,7 +73,7 @@ void sdpa_full_self_attention_metal(
  std::string hash_name = kname.str();

  auto& compute_encoder = d.get_command_encoder(s.index);
-   auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
  auto kernel = d.get_kernel(base_name, hash_name, func_consts);
  compute_encoder.set_compute_pipeline_state(kernel);

  const int NQ = (qL + bq - 1) / bq;
@ -180,7 +180,7 @@ void sdpa_vector(
  // Get the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
-   auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
  auto kernel = d.get_kernel(kname, hash_name, func_consts);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Set its arguments
@ -281,7 +281,7 @@ void sdpa_vector_2pass(
  // Get the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
-   auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
  auto kernel = d.get_kernel(kname, hash_name, func_consts);
  compute_encoder.set_compute_pipeline_state(kernel);

View File

@ -2,6 +2,7 @@
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/distributed/primitives.h" #include "mlx/distributed/primitives.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
#define NO_GPU_MULTI(func) \ #define NO_GPU_MULTI(func) \
@ -155,6 +156,18 @@ NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
MetalKernelFunction metal_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool ensure_row_contiguous,
bool atomic_outputs) {
throw std::runtime_error("[metal_kernel] No GPU back-end.");
}
} // namespace fast

namespace distributed {

View File

@ -1,10 +1,7 @@
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
-   #include <iostream>
#include <numeric>
-   #include <regex>

-   #include "mlx/backend/common/compiled.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
@ -1027,303 +1024,4 @@ std::vector<Shape> AffineQuantize::output_shapes(
  }
}
std::string write_signature(
std::string func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<std::string>& attributes,
const std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 16384);
kernel_source += header;
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
if (!template_args.empty()) {
kernel_source += "template <";
int i = 0;
for (const auto& [name, arg] : template_args) {
std::string param_type;
if (std::holds_alternative<int>(arg)) {
param_type = "int";
} else if (std::holds_alternative<bool>(arg)) {
param_type = "bool";
} else if (std::holds_alternative<Dtype>(arg)) {
param_type = "typename";
}
if (i > 0) {
kernel_source += ", ";
}
kernel_source += param_type;
kernel_source += " ";
kernel_source += name;
i++;
}
kernel_source += ">\n";
}
kernel_source += "[[kernel]] void ";
kernel_source += func_name;
kernel_source += "(\n";
int index = 0;
constexpr int max_constant_array_size = 8;
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
auto dtype = get_type_string(arr.dtype());
std::string location =
arr.size() < max_constant_array_size ? "constant" : "device";
std::string ref = arr.ndim() == 0 ? "&" : "*";
kernel_source += " const ";
kernel_source += location;
kernel_source += " ";
kernel_source += dtype;
kernel_source += ref;
kernel_source += " ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]],\n";
index++;
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source +=
(" const constant int* " + name + "_shape [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].strides) {
kernel_source +=
(" const constant int64_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].ndim) {
kernel_source +=
(" const constant int& " + name + "_ndim [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " device ";
auto type_string = get_type_string(dtype);
if (atomic_outputs) {
kernel_source += "atomic<";
}
kernel_source += type_string;
if (atomic_outputs) {
kernel_source += ">";
}
kernel_source += "* ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]]";
if (index < inputs.size() + output_names.size() - 1 ||
attributes.size() > 0) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
index = 0;
for (const auto& attr : attributes) {
kernel_source += attr;
if (index < attributes.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
kernel_source += source;
kernel_source += "\n}\n";
return kernel_source;
}
std::string write_template(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
std::ostringstream template_def;
template_def << "<";
int i = 0;
for (const auto& [name, arg] : template_args) {
if (i > 0) {
template_def << ", ";
}
if (std::holds_alternative<int>(arg)) {
template_def << std::get<int>(arg);
} else if (std::holds_alternative<bool>(arg)) {
template_def << std::get<bool>(arg);
} else if (std::holds_alternative<Dtype>(arg)) {
template_def << get_type_string(std::get<Dtype>(arg));
}
i++;
}
template_def << ">";
return template_def.str();
}
MetalKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header /* = "" */,
bool ensure_row_contiguous /* = true */,
bool atomic_outputs /* = false */) {
if (output_names.empty()) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
{"dispatch_quadgroups_per_threadgroup", "uint"},
{"dispatch_simdgroups_per_threadgroup", "uint"},
{"dispatch_threads_per_threadgroup", "uint3"},
{"grid_origin", "uint3"},
{"grid_size", "uint3"},
{"quadgroup_index_in_threadgroup", "uint"},
{"quadgroups_per_threadgroup", "uint"},
{"simdgroup_index_in_threadgroup", "uint"},
{"simdgroups_per_threadgroup", "uint"},
{"thread_execution_width", "uint"},
{"thread_index_in_quadgroup", "uint"},
{"thread_index_in_simdgroup", "uint"},
{"thread_index_in_threadgroup", "uint"},
{"thread_position_in_grid", "uint3"},
{"thread_position_in_threadgroup", "uint3"},
{"threadgroup_position_in_grid", "uint3"},
{"threadgroups_per_grid", "uint3"},
{"threads_per_grid", "uint3"},
{"threads_per_simdgroup", "uint"},
{"threads_per_threadgroup", "uint3"},
};
std::vector<std::string> attributes;
for (const auto& [attr, dtype] : metal_attributes) {
if (source.find(attr) != std::string::npos) {
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
}
}
return [=,
shape_infos = std::move(shape_infos),
attributes = std::move(attributes)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
}
std::ostringstream func_name;
std::string template_def = "";
std::string hash_key = "";
if (!template_args.empty()) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args);
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
hash_key.pop_back();
}
func_name << "custom_kernel_" << name << hash_key;
std::string kernel_name = func_name.str();
std::string kernel_source = write_signature(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
attributes,
shape_infos,
atomic_outputs);
if (!template_args.empty()) {
template_def = kernel_name + template_def;
kernel_source += "\ntemplate [[host_name(\"";
kernel_source += kernel_name;
kernel_source += "\")]] [[kernel]] decltype(";
kernel_source += template_def;
kernel_source += ") ";
kernel_source += template_def;
kernel_source += ";\n";
}
if (verbose) {
std::cout << "Generated source code for `" << name << "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value),
std::move(inputs));
};
}
} // namespace mlx::core::fast

View File

@ -735,6 +735,41 @@ class TestFast(mlx_tests.MLXTestCase):
        )[0]
        self.assertEqual(out.item(), 2)
    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_custom_kernel_caching(self):
        def call_kernel(a: mx.array, source):
            kernel = mx.fast.metal_kernel(
                name="my_kernel",
                input_names=["inp"],
                output_names=["out"],
                source=source,
            )
            return kernel(
                inputs=[a],
                grid=(a.size, 1, 1),
                threadgroup=(a.size, 1, 1),
                output_shapes=[a.shape],
                output_dtypes=[a.dtype],
                stream=mx.gpu,
            )[0]

        a = mx.random.normal(shape=(32,))

        source = """
            uint elem = thread_position_in_grid.x;
            out[elem] = 0.0;
        """
        out = call_kernel(a, source)
        self.assertTrue(mx.array_equal(out, mx.zeros_like(out)))

        source = """
            uint elem = thread_position_in_grid.x;
            out[elem] = 1.0;
        """
        out = call_kernel(a, source)
        self.assertTrue(mx.array_equal(out, mx.ones_like(out)))
if __name__ == "__main__":
    unittest.main()