Fix unintuitive metal kernel caching

This commit is contained in:
Awni Hannun 2025-06-03 08:01:53 -07:00
parent c6a20b427a
commit ac1117b224
4 changed files with 313 additions and 244 deletions

View File

@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example Simple Example
-------------- --------------
.. currentmodule:: mlx.core
Let's write a custom kernel that computes ``exp`` elementwise: Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python .. code-block:: python
def exp_elementwise(a: mx.array): source = """
source = """ uint elem = thread_position_in_grid.x;
uint elem = thread_position_in_grid.x; T tmp = inp[elem];
T tmp = inp[elem]; out[elem] = metal::exp(tmp);
out[elem] = metal::exp(tmp); """
"""
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="myexp", name="myexp",
input_names=["inp"], input_names=["inp"],
output_names=["out"], output_names=["out"],
source=source, source=source,
) )
def exp_elementwise(a: mx.array):
outputs = kernel( outputs = kernel(
inputs=[a], inputs=[a],
template=[("T", mx.float32)], template=[("T", mx.float32)],
@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
b = exp_elementwise(a) b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a)) assert mx.allclose(b, mx.exp(a))
Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.
.. note:: .. note::
We are only required to pass the body of the Metal kernel in ``source``. Only pass the body of the Metal kernel in ``source``. The function
signature is generated automatically.
The full function signature will be generated using: The full function signature will be generated using:
@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>; template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function. Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups. <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension. function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes. Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
generated code for debugging purposes.
Using Shape/Strides Using Shape/Strides
------------------- -------------------
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default. :func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous. is ``True`` by default. This will copy the array inputs if needed
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims before the kernel is launched to ensure that the memory layout is row
when indexing. contiguous. Generally this makes writing the kernel easier, since we don't
have to worry about gaps or the ordering of the dims when indexing.
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
input array ``a`` if any are present in ``source``. ``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
We can then use MLX's built in indexing utils to fetch the right elements for each thread. present in ``source``. We can then use MLX's built in indexing utils to fetch
the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``: Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python .. code-block:: python
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
)
def exp_elementwise(a: mx.array): def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
)
outputs = kernel( outputs = kernel(
inputs=[a], inputs=[a],
template=[("T", mx.float32)], template=[("T", mx.float32)],
@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops:
.. code-block:: python .. code-block:: python
def grid_sample_ref(x, grid): def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2 ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2 iy = ((grid[..., 1] + 1) * H_in - 1) / 2
ix_nw = mx.floor(ix).astype(mx.int32) ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32) iy_nw = mx.floor(iy).astype(mx.int32)
ix_ne = ix_nw + 1 ix_ne = ix_nw + 1
iy_ne = iy_nw iy_ne = iy_nw
ix_sw = ix_nw ix_sw = ix_nw
iy_sw = iy_nw + 1 iy_sw = iy_nw + 1
ix_se = ix_nw + 1 ix_se = ix_nw + 1
iy_se = iy_nw + 1 iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy) nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy) ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne) sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw) se = (ix - ix_nw) * (iy - iy_nw)
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
I_nw *= mask_nw[..., None] I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None] I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None] I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None] I_se *= mask_se[..., None]
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
return output return output
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel`` Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
to write a fast GPU kernel for both the forward and backward passes. to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel: First we'll implement the forward pass as a fused kernel:
.. code-block:: python .. code-block:: python
@mx.custom_function source = """
def grid_sample(x, grid): uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
assert x.ndim == 4, "`x` must be 4D." int w_stride = C;
assert grid.ndim == 4, "`grid` must be 4D." int h_stride = W * w_stride;
int b_stride = H * h_stride;
B, _, _, C = x.shape uint grid_idx = elem / C * 2;
_, gN, gM, D = grid.shape float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
out_shape = (B, gN, gM, C) float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
assert D == 2, "Last dim of `grid` must be size 2." int ix_nw = floor(ix);
int iy_nw = floor(iy);
source = """ int ix_ne = ix_nw + 1;
uint elem = thread_position_in_grid.x; int iy_ne = iy_nw;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C; int ix_sw = ix_nw;
int h_stride = W * w_stride; int iy_sw = iy_nw + 1;
int b_stride = H * h_stride;
uint grid_idx = elem / C * 2; int ix_se = ix_nw + 1;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2; int iy_se = iy_nw + 1;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix); T nw = (ix_se - ix) * (iy_se - iy);
int iy_nw = floor(iy); T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int ix_ne = ix_nw + 1; int batch_idx = elem / C / gH / gW * b_stride;
int iy_ne = iy_nw; int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
int ix_sw = ix_nw; T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
int iy_sw = iy_nw + 1; T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
int ix_se = ix_nw + 1; I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
int iy_se = iy_nw + 1; I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
T nw = (ix_se - ix) * (iy_se - iy); out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
T ne = (ix - ix_sw) * (iy_sw - iy); """
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C / gH / gW * b_stride; kernel = mx.fast.metal_kernel(
int channel_idx = elem % C; name="grid_sample",
int base_idx = batch_idx + channel_idx; input_names=["x", "grid"],
output_names=["out"],
source=source,
)
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; @mx.custom_function
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; def grid_sample(x, grid):
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; assert x.ndim == 4, "`x` must be 4D."
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; assert grid.ndim == 4, "`grid` must be 4D."
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; B, _, _, C = x.shape
""" _, gN, gM, D = grid.shape
kernel = mx.fast.metal_kernel( out_shape = (B, gN, gM, C)
name="grid_sample",
input_names=["x", "grid"], assert D == 2, "Last dim of `grid` must be size 2."
output_names=["out"],
source=source, outputs = kernel(
) inputs=[x, grid],
outputs = kernel( template=[("T", x.dtype)],
inputs=[x, grid], output_shapes=[out_shape],
template=[("T", x.dtype)], output_dtypes=[x.dtype],
output_shapes=[out_shape], grid=(np.prod(out_shape), 1, 1),
output_dtypes=[x.dtype], threadgroup=(256, 1, 1),
grid=(np.prod(out_shape), 1, 1), )
threadgroup=(256, 1, 1), return outputs[0]
)
return outputs[0]
For a reasonably sized input such as: For a reasonably sized input such as:
.. code-block:: python .. code-block:: python
x.shape = (8, 1024, 1024, 64) x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2) grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement: On an M1 Max, we see a big performance improvement:
@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
Grid Sample VJP Grid Sample VJP
--------------- ---------------
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
its custom vjp transform so MLX can differentiate it. define its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra ``mx.fast.metal_kernel`` features: requires a few extra :func:`fast.metal_kernel` features:
* ``init_value=0`` * ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel. Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
@ -299,128 +316,129 @@ We can then implement the backwards pass as follows:
.. code-block:: python .. code-block:: python
@grid_sample.vjp source = """
def grid_sample_vjp(primals, cotangent, _): uint elem = thread_position_in_grid.x;
x, grid = primals int H = x_shape[1];
B, _, _, C = x.shape int W = x_shape[2];
_, gN, gM, D = grid.shape int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
assert D == 2, "Last dim of `grid` must be size 2." int gH = grid_shape[1];
int gW = grid_shape[2];
source = """ int w_stride = C;
uint elem = thread_position_in_grid.x; int h_stride = W * w_stride;
int H = x_shape[1]; int b_stride = H * h_stride;
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
int gH = grid_shape[1]; uint grid_idx = elem / C_padded * 2;
int gW = grid_shape[2]; float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int w_stride = C; int ix_nw = floor(ix);
int h_stride = W * w_stride; int iy_nw = floor(iy);
int b_stride = H * h_stride;
uint grid_idx = elem / C_padded * 2; int ix_ne = ix_nw + 1;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2; int iy_ne = iy_nw;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix); int ix_sw = ix_nw;
int iy_nw = floor(iy); int iy_sw = iy_nw + 1;
int ix_ne = ix_nw + 1; int ix_se = ix_nw + 1;
int iy_ne = iy_nw; int iy_se = iy_nw + 1;
int ix_sw = ix_nw; T nw = (ix_se - ix) * (iy_se - iy);
int iy_sw = iy_nw + 1; T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int ix_se = ix_nw + 1; int batch_idx = elem / C_padded / gH / gW * b_stride;
int iy_se = iy_nw + 1; int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T nw = (ix_se - ix) * (iy_se - iy); T gix = T(0);
T ne = (ix - ix_sw) * (iy_sw - iy); T giy = T(0);
T sw = (ix_ne - ix) * (iy - iy_ne); if (channel_idx < C) {
T se = (ix - ix_nw) * (iy - iy_nw); int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
int batch_idx = elem / C_padded / gH / gW * b_stride; T I_nw = x[offset];
int channel_idx = elem % C_padded; gix -= I_nw * (iy_se - iy) * cot;
int base_idx = batch_idx + channel_idx; giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T gix = T(0); T I_ne = x[offset];
T giy = T(0); gix += I_ne * (iy_sw - iy) * cot;
if (channel_idx < C) { giy -= I_ne * (ix - ix_sw) * cot;
int cot_index = elem / C_padded * C + channel_idx; }
T cot = cotangent[cot_index]; if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) { int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_nw = x[offset]; T I_sw = x[offset];
gix -= I_nw * (iy_se - iy) * cot; gix -= I_sw * (iy - iy_ne) * cot;
giy -= I_nw * (ix_se - ix) * cot; giy += I_sw * (ix_ne - ix) * cot;
} }
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_ne = x[offset]; T I_se = x[offset];
gix += I_ne * (iy_sw - iy) * cot; gix += I_se * (iy - iy_nw) * cot;
giy -= I_ne * (ix - ix_sw) * cot; giy += I_se * (ix - ix_nw) * cot;
} }
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { }
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T I_sw = x[offset]; T gix_mult = W / 2;
gix -= I_sw * (iy - iy_ne) * cot; T giy_mult = H / 2;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_se = x[offset]; // Reduce across each simdgroup first.
gix += I_se * (iy - iy_nw) * cot; // This is much faster than relying purely on atomics.
giy += I_se * (ix - ix_nw) * cot; gix = simd_sum(gix);
} giy = simd_sum(giy);
}
T gix_mult = W / 2; if (thread_index_in_simdgroup == 0) {
T giy_mult = H / 2; atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
// Reduce across each simdgroup first. @grid_sample.vjp
// This is much faster than relying purely on atomics. def grid_sample_vjp(primals, cotangent, _):
gix = simd_sum(gix); x, grid = primals
giy = simd_sum(giy); B, _, _, C = x.shape
_, gN, gM, D = grid.shape
if (thread_index_in_simdgroup == 0) { assert D == 2, "Last dim of `grid` must be size 2."
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); # pad the output channels to simd group size
} # so that our `simd_sum`s don't overlap.
""" simdgroup_size = 32
kernel = mx.fast.metal_kernel( C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
name="grid_sample_grad", grid_size = B * gN * gM * C_padded
input_names=["x", "grid", "cotangent"], outputs = kernel(
output_names=["x_grad", "grid_grad"], inputs=[x, grid, cotangent],
source=source, template=[("T", x.dtype)],
atomic_outputs=True, output_shapes=[x.shape, grid.shape],
) output_dtypes=[x.dtype, x.dtype],
# pad the output channels to simd group size grid=(grid_size, 1, 1),
# so that our `simd_sum`s don't overlap. threadgroup=(256, 1, 1),
simdgroup_size = 32 init_value=0,
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size )
grid_size = B * gN * gM * C_padded return outputs[0], outputs[1]
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
There's an even larger speed up for the vjp: There's an even larger speed up for the vjp:

View File

@ -73,6 +73,16 @@ void CustomKernel::eval_gpu(
} }
const auto [tx, ty, tz] = threadgroup_; const auto [tx, ty, tz] = threadgroup_;
auto tg_size = tx * ty * tz;
auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
if (tg_size > max_tg_size) {
std::ostringstream msg;
msg << "Thread group size (" << tg_size << ") is greater than "
<< " the maximum allowed threads per threadgroup (" << max_tg_size
<< ").";
throw std::invalid_argument(msg.str());
}
const auto [gx, gy, gz] = grid_; const auto [gx, gy, gz] = grid_;
MTL::Size group_dims = MTL::Size group_dims =
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz)); MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));

View File

@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <cassert> #include <cassert>
#include <chrono>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <regex> #include <regex>
@ -1228,6 +1229,10 @@ MetalKernelFunction metal_kernel(
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]"); attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
} }
} }
auto now = std::chrono::system_clock::now();
int64_t timestamp = std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch())
.count();
return [=, return [=,
shape_infos = std::move(shape_infos), shape_infos = std::move(shape_infos),
@ -1271,14 +1276,15 @@ MetalKernelFunction metal_kernel(
std::ostringstream func_name; std::ostringstream func_name;
std::string template_def = ""; std::string template_def = "";
std::string hash_key = ""; std::string template_hash = "";
if (!template_args.empty()) { if (!template_args.empty()) {
std::regex disallowed_chars("\\<|\\>|(, )"); std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args); template_def = write_template(template_args);
hash_key = std::regex_replace(template_def, disallowed_chars, "_"); template_hash = std::regex_replace(template_def, disallowed_chars, "_");
hash_key.pop_back(); template_hash.pop_back();
} }
func_name << "custom_kernel_" << name << hash_key; func_name << "custom_kernel_" << name << "_" << template_hash << "_"
<< timestamp;
std::string kernel_name = func_name.str(); std::string kernel_name = func_name.str();
std::string kernel_source = write_signature( std::string kernel_source = write_signature(

View File

@ -735,6 +735,41 @@ class TestFast(mlx_tests.MLXTestCase):
)[0] )[0]
self.assertEqual(out.item(), 2) self.assertEqual(out.item(), 2)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_caching(self):
def call_kernel(a: mx.array, source):
kernel = mx.fast.metal_kernel(
name="my_kernel",
input_names=["inp"],
output_names=["out"],
source=source,
)
return kernel(
inputs=[a],
grid=(a.size, 1, 1),
threadgroup=(a.size, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
stream=mx.gpu,
)[0]
a = mx.random.normal(shape=(32,))
source = """
uint elem = thread_position_in_grid.x;
out[elem] = 0.0;
"""
out = call_kernel(a, source)
self.assertTrue(mx.array_equal(out, mx.zeros_like(out)))
source = """
uint elem = thread_position_in_grid.x;
out[elem] = 1.0;
"""
out = call_kernel(a, source)
self.assertTrue(mx.array_equal(out, mx.ones_like(out)))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()