From ac1117b224351de3917c81306802ff04eab3482b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 3 Jun 2025 08:01:53 -0700 Subject: [PATCH] Fix unintuitive metal kernel caching --- docs/src/dev/custom_metal_kernels.rst | 498 +++++++++++++------------- mlx/backend/metal/custom_kernel.cpp | 10 + mlx/fast.cpp | 14 +- python/tests/test_fast.py | 35 ++ 4 files changed, 313 insertions(+), 244 deletions(-) diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst index 3e92f2814..873b1e544 100644 --- a/docs/src/dev/custom_metal_kernels.rst +++ b/docs/src/dev/custom_metal_kernels.rst @@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs. Simple Example -------------- +.. currentmodule:: mlx.core + Let's write a custom kernel that computes ``exp`` elementwise: .. code-block:: python - def exp_elementwise(a: mx.array): - source = """ - uint elem = thread_position_in_grid.x; - T tmp = inp[elem]; - out[elem] = metal::exp(tmp); - """ + source = """ + uint elem = thread_position_in_grid.x; + T tmp = inp[elem]; + out[elem] = metal::exp(tmp); + """ - kernel = mx.fast.metal_kernel( - name="myexp", - input_names=["inp"], - output_names=["out"], - source=source, - ) + kernel = mx.fast.metal_kernel( + name="myexp", + input_names=["inp"], + output_names=["out"], + source=source, + ) + + def exp_elementwise(a: mx.array): outputs = kernel( inputs=[a], template=[("T", mx.float32)], @@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise: b = exp_elementwise(a) assert mx.allclose(b, mx.exp(a)) +Every time you make a kernel, a new Metal library is created and possibly +JIT compiled. To reduce the overhead from that, build the kernel once with +:func:`fast.metal_kernel` and then use it many times. + .. note:: - We are only required to pass the body of the Metal kernel in ``source``. + Only pass the body of the Metal kernel in ``source``. The function + signature is generated automatically. The full function signature will be generated using: @@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float) custom_kernel_myexp_float; -Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads `_ function. -This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups. -For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension. +Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads +`_ +function. This means we will launch ``mx.prod(grid)`` threads, subdivided into +``threadgroup`` size threadgroups. For optimal performance, each thread group +dimension should be less than or equal to the corresponding grid dimension. -Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes. +Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the +generated code for debugging purposes. Using Shape/Strides ------------------- -``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default. -This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous. 
-Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims -when indexing. +:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which +is ``True`` by default. This will copy the array inputs if needed +before the kernel is launched to ensure that the memory layout is row +contiguous. Generally this makes writing the kernel easier, since we don't +have to worry about gaps or the ordering of the dims when indexing. -If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each -input array ``a`` if any are present in ``source``. -We can then use MLX's built in indexing utils to fetch the right elements for each thread. +If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes +``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are +present in ``source``. We can then use MLX's built in indexing utils to fetch +the right elements for each thread. -Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``: +Let's convert ``myexp`` above to support arbitrarily strided arrays without +relying on a copy from ``ensure_row_contiguous``: .. code-block:: python + + source = """ + uint elem = thread_position_in_grid.x; + // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included + uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); + T tmp = inp[loc]; + // Output arrays are always row contiguous + out[elem] = metal::exp(tmp); + """ + + kernel = mx.fast.metal_kernel( + name="myexp_strided", + input_names=["inp"], + output_names=["out"], + source=source + ) def exp_elementwise(a: mx.array): - source = """ - uint elem = thread_position_in_grid.x; - // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included - uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); - T tmp = inp[loc]; - // Output arrays are always row contiguous - out[elem] = metal::exp(tmp); - """ - - kernel = mx.fast.metal_kernel( - name="myexp_strided", - input_names=["inp"], - output_names=["out"], - source=source - ) outputs = kernel( inputs=[a], template=[("T", mx.float32)], @@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops: .. 
code-block:: python - def grid_sample_ref(x, grid): - N, H_in, W_in, _ = x.shape - ix = ((grid[..., 0] + 1) * W_in - 1) / 2 - iy = ((grid[..., 1] + 1) * H_in - 1) / 2 + def grid_sample_ref(x, grid): + N, H_in, W_in, _ = x.shape + ix = ((grid[..., 0] + 1) * W_in - 1) / 2 + iy = ((grid[..., 1] + 1) * H_in - 1) / 2 - ix_nw = mx.floor(ix).astype(mx.int32) - iy_nw = mx.floor(iy).astype(mx.int32) + ix_nw = mx.floor(ix).astype(mx.int32) + iy_nw = mx.floor(iy).astype(mx.int32) - ix_ne = ix_nw + 1 - iy_ne = iy_nw + ix_ne = ix_nw + 1 + iy_ne = iy_nw - ix_sw = ix_nw - iy_sw = iy_nw + 1 + ix_sw = ix_nw + iy_sw = iy_nw + 1 - ix_se = ix_nw + 1 - iy_se = iy_nw + 1 + ix_se = ix_nw + 1 + iy_se = iy_nw + 1 - nw = (ix_se - ix) * (iy_se - iy) - ne = (ix - ix_sw) * (iy_sw - iy) - sw = (ix_ne - ix) * (iy - iy_ne) - se = (ix - ix_nw) * (iy - iy_nw) + nw = (ix_se - ix) * (iy_se - iy) + ne = (ix - ix_sw) * (iy_sw - iy) + sw = (ix_ne - ix) * (iy - iy_ne) + se = (ix - ix_nw) * (iy - iy_nw) - I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] - I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] - I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] - I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] + I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] + I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] + I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] + I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] - mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) - mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) - mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) - mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) + mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) + mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) + mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) + mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) - I_nw *= mask_nw[..., None] - I_ne *= mask_ne[..., None] - I_sw *= mask_sw[..., None] - I_se *= mask_se[..., None] + I_nw *= mask_nw[..., None] + I_ne *= mask_ne[..., None] + I_sw *= mask_sw[..., None] + I_se *= mask_se[..., None] - output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se + output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se - return output + return output -Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel`` +Now let's use :func:`custom_function` together with :func:`fast.metal_kernel` to write a fast GPU kernel for both the forward and backward passes. First we'll implement the forward pass as a fused kernel: .. code-block:: python - @mx.custom_function - def grid_sample(x, grid): + source = """ + uint elem = thread_position_in_grid.x; + int H = x_shape[1]; + int W = x_shape[2]; + int C = x_shape[3]; + int gH = grid_shape[1]; + int gW = grid_shape[2]; - assert x.ndim == 4, "`x` must be 4D." - assert grid.ndim == 4, "`grid` must be 4D." + int w_stride = C; + int h_stride = W * w_stride; + int b_stride = H * h_stride; - B, _, _, C = x.shape - _, gN, gM, D = grid.shape - out_shape = (B, gN, gM, C) + uint grid_idx = elem / C * 2; + float ix = ((grid[grid_idx] + 1) * W - 1) / 2; + float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; - assert D == 2, "Last dim of `grid` must be size 2." 
+ int ix_nw = floor(ix); + int iy_nw = floor(iy); - source = """ - uint elem = thread_position_in_grid.x; - int H = x_shape[1]; - int W = x_shape[2]; - int C = x_shape[3]; - int gH = grid_shape[1]; - int gW = grid_shape[2]; + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; - int w_stride = C; - int h_stride = W * w_stride; - int b_stride = H * h_stride; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; - uint grid_idx = elem / C * 2; - float ix = ((grid[grid_idx] + 1) * W - 1) / 2; - float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; - int ix_nw = floor(ix); - int iy_nw = floor(iy); + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; + int batch_idx = elem / C / gH / gW * b_stride; + int channel_idx = elem % C; + int base_idx = batch_idx + channel_idx; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; + T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; + T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; + T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride]; + T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride]; - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; + I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; + I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; + I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0; + I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0; - T nw = (ix_se - ix) * (iy_se - iy); - T ne = (ix - ix_sw) * (iy_sw - iy); - T sw = (ix_ne - ix) * (iy - iy_ne); - T se = (ix - ix_nw) * (iy - iy_nw); + out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; + """ - int batch_idx = elem / C / gH / gW * b_stride; - int channel_idx = elem % C; - int base_idx = batch_idx + channel_idx; + kernel = mx.fast.metal_kernel( + name="grid_sample", + input_names=["x", "grid"], + output_names=["out"], + source=source, + ) - T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; - T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; - T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride]; - T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride]; + @mx.custom_function + def grid_sample(x, grid): - I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; - I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; - I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0; - I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0; + assert x.ndim == 4, "`x` must be 4D." + assert grid.ndim == 4, "`grid` must be 4D." - out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; - """ - kernel = mx.fast.metal_kernel( - name="grid_sample", - input_names=["x", "grid"], - output_names=["out"], - source=source, - ) - outputs = kernel( - inputs=[x, grid], - template=[("T", x.dtype)], - output_shapes=[out_shape], - output_dtypes=[x.dtype], - grid=(np.prod(out_shape), 1, 1), - threadgroup=(256, 1, 1), - ) - return outputs[0] + B, _, _, C = x.shape + _, gN, gM, D = grid.shape + out_shape = (B, gN, gM, C) + + assert D == 2, "Last dim of `grid` must be size 2." 
+ + outputs = kernel( + inputs=[x, grid], + template=[("T", x.dtype)], + output_shapes=[out_shape], + output_dtypes=[x.dtype], + grid=(np.prod(out_shape), 1, 1), + threadgroup=(256, 1, 1), + ) + return outputs[0] For a reasonably sized input such as: .. code-block:: python - x.shape = (8, 1024, 1024, 64) - grid.shape = (8, 256, 256, 2) + x.shape = (8, 1024, 1024, 64) + grid.shape = (8, 256, 256, 2) On an M1 Max, we see a big performance improvement: @@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement: Grid Sample VJP --------------- -Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define -its custom vjp transform so MLX can differentiate it. +Since we decorated ``grid_sample`` with :func:`custom_function`, we can now +define its custom vjp transform so MLX can differentiate it. The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so -requires a few extra ``mx.fast.metal_kernel`` features: +requires a few extra :func:`fast.metal_kernel` features: * ``init_value=0`` Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel. @@ -299,128 +316,129 @@ We can then implement the backwards pass as follows: .. code-block:: python - @grid_sample.vjp - def grid_sample_vjp(primals, cotangent, _): - x, grid = primals - B, _, _, C = x.shape - _, gN, gM, D = grid.shape + source = """ + uint elem = thread_position_in_grid.x; + int H = x_shape[1]; + int W = x_shape[2]; + int C = x_shape[3]; + // Pad C to the nearest larger simdgroup size multiple + int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup; - assert D == 2, "Last dim of `grid` must be size 2." + int gH = grid_shape[1]; + int gW = grid_shape[2]; - source = """ - uint elem = thread_position_in_grid.x; - int H = x_shape[1]; - int W = x_shape[2]; - int C = x_shape[3]; - // Pad C to the nearest larger simdgroup size multiple - int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup; + int w_stride = C; + int h_stride = W * w_stride; + int b_stride = H * h_stride; - int gH = grid_shape[1]; - int gW = grid_shape[2]; + uint grid_idx = elem / C_padded * 2; + float ix = ((grid[grid_idx] + 1) * W - 1) / 2; + float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; - int w_stride = C; - int h_stride = W * w_stride; - int b_stride = H * h_stride; + int ix_nw = floor(ix); + int iy_nw = floor(iy); - uint grid_idx = elem / C_padded * 2; - float ix = ((grid[grid_idx] + 1) * W - 1) / 2; - float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; - int ix_nw = floor(ix); - int iy_nw = floor(iy); + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; + int batch_idx = elem / C_padded / gH / gW * b_stride; + int channel_idx = elem % C_padded; + int base_idx = batch_idx + channel_idx; - T nw = (ix_se - ix) * (iy_se - iy); - T ne = (ix - ix_sw) * (iy_sw - iy); - T sw = (ix_ne - ix) * (iy - iy_ne); - T se = (ix - ix_nw) * (iy - iy_nw); + T gix = T(0); + T giy = T(0); + if (channel_idx < C) { + int cot_index = elem / C_padded * C + channel_idx; + T cot = cotangent[cot_index]; + if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 
0 && ix_nw <= W - 1) { + int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed); - int batch_idx = elem / C_padded / gH / gW * b_stride; - int channel_idx = elem % C_padded; - int base_idx = batch_idx + channel_idx; + T I_nw = x[offset]; + gix -= I_nw * (iy_se - iy) * cot; + giy -= I_nw * (ix_se - ix) * cot; + } + if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { + int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); - T gix = T(0); - T giy = T(0); - if (channel_idx < C) { - int cot_index = elem / C_padded * C + channel_idx; - T cot = cotangent[cot_index]; - if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) { - int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed); + T I_ne = x[offset]; + gix += I_ne * (iy_sw - iy) * cot; + giy -= I_ne * (ix - ix_sw) * cot; + } + if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { + int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed); - T I_nw = x[offset]; - gix -= I_nw * (iy_se - iy) * cot; - giy -= I_nw * (ix_se - ix) * cot; - } - if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { - int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); + T I_sw = x[offset]; + gix -= I_sw * (iy - iy_ne) * cot; + giy += I_sw * (ix_ne - ix) * cot; + } + if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) { + int offset = base_idx + iy_se * h_stride + ix_se * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed); - T I_ne = x[offset]; - gix += I_ne * (iy_sw - iy) * cot; - giy -= I_ne * (ix - ix_sw) * cot; - } - if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { - int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed); + T I_se = x[offset]; + gix += I_se * (iy - iy_nw) * cot; + giy += I_se * (ix - ix_nw) * cot; + } + } - T I_sw = x[offset]; - gix -= I_sw * (iy - iy_ne) * cot; - giy += I_sw * (ix_ne - ix) * cot; - } - if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) { - int offset = base_idx + iy_se * h_stride + ix_se * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed); + T gix_mult = W / 2; + T giy_mult = H / 2; - T I_se = x[offset]; - gix += I_se * (iy - iy_nw) * cot; - giy += I_se * (ix - ix_nw) * cot; - } - } + // Reduce across each simdgroup first. + // This is much faster than relying purely on atomics. + gix = simd_sum(gix); + giy = simd_sum(giy); - T gix_mult = W / 2; - T giy_mult = H / 2; + if (thread_index_in_simdgroup == 0) { + atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed); + atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); + } + """ + kernel = mx.fast.metal_kernel( + name="grid_sample_grad", + input_names=["x", "grid", "cotangent"], + output_names=["x_grad", "grid_grad"], + source=source, + atomic_outputs=True, + ) - // Reduce across each simdgroup first. - // This is much faster than relying purely on atomics. 
- gix = simd_sum(gix); - giy = simd_sum(giy); + @grid_sample.vjp + def grid_sample_vjp(primals, cotangent, _): + x, grid = primals + B, _, _, C = x.shape + _, gN, gM, D = grid.shape - if (thread_index_in_simdgroup == 0) { - atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed); - atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); - } - """ - kernel = mx.fast.metal_kernel( - name="grid_sample_grad", - input_names=["x", "grid", "cotangent"], - output_names=["x_grad", "grid_grad"], - source=source, - atomic_outputs=True, - ) - # pad the output channels to simd group size - # so that our `simd_sum`s don't overlap. - simdgroup_size = 32 - C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size - grid_size = B * gN * gM * C_padded - outputs = kernel( - inputs=[x, grid, cotangent], - template=[("T", x.dtype)], - output_shapes=[x.shape, grid.shape], - output_dtypes=[x.dtype, x.dtype], - grid=(grid_size, 1, 1), - threadgroup=(256, 1, 1), - init_value=0, - ) - return outputs[0], outputs[1] + assert D == 2, "Last dim of `grid` must be size 2." + + # pad the output channels to simd group size + # so that our `simd_sum`s don't overlap. + simdgroup_size = 32 + C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size + grid_size = B * gN * gM * C_padded + outputs = kernel( + inputs=[x, grid, cotangent], + template=[("T", x.dtype)], + output_shapes=[x.shape, grid.shape], + output_dtypes=[x.dtype, x.dtype], + grid=(grid_size, 1, 1), + threadgroup=(256, 1, 1), + init_value=0, + ) + return outputs[0], outputs[1] There's an even larger speed up for the vjp: diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index ea4f258cc..0240126b1 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -73,6 +73,16 @@ void CustomKernel::eval_gpu( } const auto [tx, ty, tz] = threadgroup_; + auto tg_size = tx * ty * tz; + auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup(); + if (tg_size > max_tg_size) { + std::ostringstream msg; + msg << "Thread group size (" << tg_size << ") is greater than " + << " the maximum allowed threads per threadgroup (" << max_tg_size + << ")."; + throw std::invalid_argument(msg.str()); + } + const auto [gx, gy, gz] = grid_; MTL::Size group_dims = MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz)); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index eab22f14d..7a86f8d18 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1,5 +1,6 @@ // Copyright © 2023-2024 Apple Inc. 
 #include 
+#include <chrono>
 #include 
 #include 
 #include 
@@ -1228,6 +1229,10 @@ MetalKernelFunction metal_kernel(
       attributes.push_back("  " + dtype + " " + attr + " [[" + attr + "]]");
     }
   }
 
+  auto now = std::chrono::system_clock::now();
+  int64_t timestamp = std::chrono::duration_cast(
+                          now.time_since_epoch())
+                          .count();
   return [=,
           shape_infos = std::move(shape_infos),
@@ -1271,14 +1276,15 @@ MetalKernelFunction metal_kernel(
     std::ostringstream func_name;
     std::string template_def = "";
-    std::string hash_key = "";
+    std::string template_hash = "";
     if (!template_args.empty()) {
       std::regex disallowed_chars("\\<|\\>|(, )");
       template_def = write_template(template_args);
-      hash_key = std::regex_replace(template_def, disallowed_chars, "_");
-      hash_key.pop_back();
+      template_hash = std::regex_replace(template_def, disallowed_chars, "_");
+      template_hash.pop_back();
     }
-    func_name << "custom_kernel_" << name << hash_key;
+    func_name << "custom_kernel_" << name << "_" << template_hash << "_"
+              << timestamp;
 
     std::string kernel_name = func_name.str();
 
     std::string kernel_source = write_signature(
diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py
index 2c90a3755..59c2fc3ef 100644
--- a/python/tests/test_fast.py
+++ b/python/tests/test_fast.py
@@ -735,6 +735,41 @@ class TestFast(mlx_tests.MLXTestCase):
         )[0]
         self.assertEqual(out.item(), 2)
 
+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    def test_custom_kernel_caching(self):
+        def call_kernel(a: mx.array, source):
+            kernel = mx.fast.metal_kernel(
+                name="my_kernel",
+                input_names=["inp"],
+                output_names=["out"],
+                source=source,
+            )
+            return kernel(
+                inputs=[a],
+                grid=(a.size, 1, 1),
+                threadgroup=(a.size, 1, 1),
+                output_shapes=[a.shape],
+                output_dtypes=[a.dtype],
+                stream=mx.gpu,
+            )[0]
+
+        a = mx.random.normal(shape=(32,))
+
+        source = """
+            uint elem = thread_position_in_grid.x;
+            out[elem] = 0.0;
+        """
+
+        out = call_kernel(a, source)
+        self.assertTrue(mx.array_equal(out, mx.zeros_like(out)))
+
+        source = """
+            uint elem = thread_position_in_grid.x;
+            out[elem] = 1.0;
+        """
+        out = call_kernel(a, source)
+        self.assertTrue(mx.array_equal(out, mx.ones_like(out)))
+
 
 if __name__ == "__main__":
     unittest.main()
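
To make the new guidance concrete, here is a minimal, self-contained sketch (in Python) of the "build the kernel once, call it many times" pattern that the updated documentation recommends. It reuses the ``myexp`` example from the docs above; the loop and the assert at the end are illustrative only and are not part of this patch.

    import mlx.core as mx

    # Creating the kernel builds (and may JIT compile) a new Metal library,
    # so do it once, outside the function that launches it.
    source = """
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    """

    kernel = mx.fast.metal_kernel(
        name="myexp",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )

    def exp_elementwise(a: mx.array) -> mx.array:
        # Only the launch parameters change per call; the compiled kernel is reused.
        return kernel(
            inputs=[a],
            template=[("T", mx.float32)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
        )[0]

    a = mx.random.normal(shape=(4, 16))
    for _ in range(100):
        b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))

With the timestamp appended to the generated kernel name in ``mlx/fast.cpp``, two kernels that happen to share a ``name`` but have different ``source`` no longer collide in the kernel cache (this is what ``test_custom_kernel_caching`` verifies). Hoisting kernel creation out of the hot path, as above, remains the cheap way to avoid rebuilding the Metal library on every call.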