From 1ca616844bc7434cb0186a302ed1afc6167970b3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 6 Jun 2025 20:08:15 -0700 Subject: [PATCH] Fix unintuitive metal kernel caching (#2242) * Fix unintuitive metal kernel caching * alternative solution --- docs/src/dev/custom_metal_kernels.rst | 498 +++++++++--------- docs/src/dev/extensions.rst | 6 +- examples/extensions/axpby/axpby.cpp | 6 +- mlx/backend/metal/conv.cpp | 2 +- mlx/backend/metal/custom_kernel.cpp | 344 +++++++++++- mlx/backend/metal/device.cpp | 66 ++- mlx/backend/metal/device.h | 16 +- mlx/backend/metal/nojit_kernels.cpp | 8 +- mlx/backend/metal/normalization.cpp | 4 +- .../metal/scaled_dot_product_attention.cpp | 6 +- mlx/backend/no_gpu/primitives.cpp | 13 + mlx/fast.cpp | 302 ----------- python/tests/test_fast.py | 35 ++ 13 files changed, 713 insertions(+), 593 deletions(-) diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst index 3e92f2814..873b1e544 100644 --- a/docs/src/dev/custom_metal_kernels.rst +++ b/docs/src/dev/custom_metal_kernels.rst @@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs. Simple Example -------------- +.. currentmodule:: mlx.core + Let's write a custom kernel that computes ``exp`` elementwise: .. code-block:: python - def exp_elementwise(a: mx.array): - source = """ - uint elem = thread_position_in_grid.x; - T tmp = inp[elem]; - out[elem] = metal::exp(tmp); - """ + source = """ + uint elem = thread_position_in_grid.x; + T tmp = inp[elem]; + out[elem] = metal::exp(tmp); + """ - kernel = mx.fast.metal_kernel( - name="myexp", - input_names=["inp"], - output_names=["out"], - source=source, - ) + kernel = mx.fast.metal_kernel( + name="myexp", + input_names=["inp"], + output_names=["out"], + source=source, + ) + + def exp_elementwise(a: mx.array): outputs = kernel( inputs=[a], template=[("T", mx.float32)], @@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise: b = exp_elementwise(a) assert mx.allclose(b, mx.exp(a)) +Every time you make a kernel, a new Metal library is created and possibly +JIT compiled. To reduce the overhead from that, build the kernel once with +:func:`fast.metal_kernel` and then use it many times. + .. note:: - We are only required to pass the body of the Metal kernel in ``source``. + Only pass the body of the Metal kernel in ``source``. The function + signature is generated automatically. The full function signature will be generated using: @@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float) custom_kernel_myexp_float; -Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads `_ function. -This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups. -For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension. +Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads +`_ +function. This means we will launch ``mx.prod(grid)`` threads, subdivided into +``threadgroup`` size threadgroups. For optimal performance, each thread group +dimension should be less than or equal to the corresponding grid dimension. -Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes. 
+Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the +generated code for debugging purposes. Using Shape/Strides ------------------- -``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default. -This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous. -Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims -when indexing. +:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which +is ``True`` by default. This will copy the array inputs if needed +before the kernel is launched to ensure that the memory layout is row +contiguous. Generally this makes writing the kernel easier, since we don't +have to worry about gaps or the ordering of the dims when indexing. -If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each -input array ``a`` if any are present in ``source``. -We can then use MLX's built in indexing utils to fetch the right elements for each thread. +If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes +``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are +present in ``source``. We can then use MLX's built in indexing utils to fetch +the right elements for each thread. -Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``: +Let's convert ``myexp`` above to support arbitrarily strided arrays without +relying on a copy from ``ensure_row_contiguous``: .. code-block:: python + + source = """ + uint elem = thread_position_in_grid.x; + // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included + uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); + T tmp = inp[loc]; + // Output arrays are always row contiguous + out[elem] = metal::exp(tmp); + """ + + kernel = mx.fast.metal_kernel( + name="myexp_strided", + input_names=["inp"], + output_names=["out"], + source=source + ) def exp_elementwise(a: mx.array): - source = """ - uint elem = thread_position_in_grid.x; - // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included - uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); - T tmp = inp[loc]; - // Output arrays are always row contiguous - out[elem] = metal::exp(tmp); - """ - - kernel = mx.fast.metal_kernel( - name="myexp_strided", - input_names=["inp"], - output_names=["out"], - source=source - ) outputs = kernel( inputs=[a], template=[("T", mx.float32)], @@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops: ..
code-block:: python - def grid_sample_ref(x, grid): - N, H_in, W_in, _ = x.shape - ix = ((grid[..., 0] + 1) * W_in - 1) / 2 - iy = ((grid[..., 1] + 1) * H_in - 1) / 2 + def grid_sample_ref(x, grid): + N, H_in, W_in, _ = x.shape + ix = ((grid[..., 0] + 1) * W_in - 1) / 2 + iy = ((grid[..., 1] + 1) * H_in - 1) / 2 - ix_nw = mx.floor(ix).astype(mx.int32) - iy_nw = mx.floor(iy).astype(mx.int32) + ix_nw = mx.floor(ix).astype(mx.int32) + iy_nw = mx.floor(iy).astype(mx.int32) - ix_ne = ix_nw + 1 - iy_ne = iy_nw + ix_ne = ix_nw + 1 + iy_ne = iy_nw - ix_sw = ix_nw - iy_sw = iy_nw + 1 + ix_sw = ix_nw + iy_sw = iy_nw + 1 - ix_se = ix_nw + 1 - iy_se = iy_nw + 1 + ix_se = ix_nw + 1 + iy_se = iy_nw + 1 - nw = (ix_se - ix) * (iy_se - iy) - ne = (ix - ix_sw) * (iy_sw - iy) - sw = (ix_ne - ix) * (iy - iy_ne) - se = (ix - ix_nw) * (iy - iy_nw) + nw = (ix_se - ix) * (iy_se - iy) + ne = (ix - ix_sw) * (iy_sw - iy) + sw = (ix_ne - ix) * (iy - iy_ne) + se = (ix - ix_nw) * (iy - iy_nw) - I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] - I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] - I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] - I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] + I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] + I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] + I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] + I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] - mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) - mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) - mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) - mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) + mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) + mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) + mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) + mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) - I_nw *= mask_nw[..., None] - I_ne *= mask_ne[..., None] - I_sw *= mask_sw[..., None] - I_se *= mask_se[..., None] + I_nw *= mask_nw[..., None] + I_ne *= mask_ne[..., None] + I_sw *= mask_sw[..., None] + I_se *= mask_se[..., None] - output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se + output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se - return output + return output -Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel`` +Now let's use :func:`custom_function` together with :func:`fast.metal_kernel` to write a fast GPU kernel for both the forward and backward passes. First we'll implement the forward pass as a fused kernel: .. code-block:: python - @mx.custom_function - def grid_sample(x, grid): + source = """ + uint elem = thread_position_in_grid.x; + int H = x_shape[1]; + int W = x_shape[2]; + int C = x_shape[3]; + int gH = grid_shape[1]; + int gW = grid_shape[2]; - assert x.ndim == 4, "`x` must be 4D." - assert grid.ndim == 4, "`grid` must be 4D." + int w_stride = C; + int h_stride = W * w_stride; + int b_stride = H * h_stride; - B, _, _, C = x.shape - _, gN, gM, D = grid.shape - out_shape = (B, gN, gM, C) + uint grid_idx = elem / C * 2; + float ix = ((grid[grid_idx] + 1) * W - 1) / 2; + float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; - assert D == 2, "Last dim of `grid` must be size 2." 
+ int ix_nw = floor(ix); + int iy_nw = floor(iy); - source = """ - uint elem = thread_position_in_grid.x; - int H = x_shape[1]; - int W = x_shape[2]; - int C = x_shape[3]; - int gH = grid_shape[1]; - int gW = grid_shape[2]; + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; - int w_stride = C; - int h_stride = W * w_stride; - int b_stride = H * h_stride; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; - uint grid_idx = elem / C * 2; - float ix = ((grid[grid_idx] + 1) * W - 1) / 2; - float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; - int ix_nw = floor(ix); - int iy_nw = floor(iy); + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; + int batch_idx = elem / C / gH / gW * b_stride; + int channel_idx = elem % C; + int base_idx = batch_idx + channel_idx; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; + T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; + T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; + T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride]; + T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride]; - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; + I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; + I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; + I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0; + I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0; - T nw = (ix_se - ix) * (iy_se - iy); - T ne = (ix - ix_sw) * (iy_sw - iy); - T sw = (ix_ne - ix) * (iy - iy_ne); - T se = (ix - ix_nw) * (iy - iy_nw); + out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; + """ - int batch_idx = elem / C / gH / gW * b_stride; - int channel_idx = elem % C; - int base_idx = batch_idx + channel_idx; + kernel = mx.fast.metal_kernel( + name="grid_sample", + input_names=["x", "grid"], + output_names=["out"], + source=source, + ) - T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; - T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; - T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride]; - T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride]; + @mx.custom_function + def grid_sample(x, grid): - I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; - I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; - I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0; - I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0; + assert x.ndim == 4, "`x` must be 4D." + assert grid.ndim == 4, "`grid` must be 4D." - out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; - """ - kernel = mx.fast.metal_kernel( - name="grid_sample", - input_names=["x", "grid"], - output_names=["out"], - source=source, - ) - outputs = kernel( - inputs=[x, grid], - template=[("T", x.dtype)], - output_shapes=[out_shape], - output_dtypes=[x.dtype], - grid=(np.prod(out_shape), 1, 1), - threadgroup=(256, 1, 1), - ) - return outputs[0] + B, _, _, C = x.shape + _, gN, gM, D = grid.shape + out_shape = (B, gN, gM, C) + + assert D == 2, "Last dim of `grid` must be size 2." 
+ + outputs = kernel( + inputs=[x, grid], + template=[("T", x.dtype)], + output_shapes=[out_shape], + output_dtypes=[x.dtype], + grid=(np.prod(out_shape), 1, 1), + threadgroup=(256, 1, 1), + ) + return outputs[0] For a reasonably sized input such as: .. code-block:: python - x.shape = (8, 1024, 1024, 64) - grid.shape = (8, 256, 256, 2) + x.shape = (8, 1024, 1024, 64) + grid.shape = (8, 256, 256, 2) On an M1 Max, we see a big performance improvement: @@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement: Grid Sample VJP --------------- -Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define -its custom vjp transform so MLX can differentiate it. +Since we decorated ``grid_sample`` with :func:`custom_function`, we can now +define its custom vjp transform so MLX can differentiate it. The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so -requires a few extra ``mx.fast.metal_kernel`` features: +requires a few extra :func:`fast.metal_kernel` features: * ``init_value=0`` Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel. @@ -299,128 +316,129 @@ We can then implement the backwards pass as follows: .. code-block:: python - @grid_sample.vjp - def grid_sample_vjp(primals, cotangent, _): - x, grid = primals - B, _, _, C = x.shape - _, gN, gM, D = grid.shape + source = """ + uint elem = thread_position_in_grid.x; + int H = x_shape[1]; + int W = x_shape[2]; + int C = x_shape[3]; + // Pad C to the nearest larger simdgroup size multiple + int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup; - assert D == 2, "Last dim of `grid` must be size 2." + int gH = grid_shape[1]; + int gW = grid_shape[2]; - source = """ - uint elem = thread_position_in_grid.x; - int H = x_shape[1]; - int W = x_shape[2]; - int C = x_shape[3]; - // Pad C to the nearest larger simdgroup size multiple - int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup; + int w_stride = C; + int h_stride = W * w_stride; + int b_stride = H * h_stride; - int gH = grid_shape[1]; - int gW = grid_shape[2]; + uint grid_idx = elem / C_padded * 2; + float ix = ((grid[grid_idx] + 1) * W - 1) / 2; + float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; - int w_stride = C; - int h_stride = W * w_stride; - int b_stride = H * h_stride; + int ix_nw = floor(ix); + int iy_nw = floor(iy); - uint grid_idx = elem / C_padded * 2; - float ix = ((grid[grid_idx] + 1) * W - 1) / 2; - float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; - int ix_nw = floor(ix); - int iy_nw = floor(iy); + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; + int batch_idx = elem / C_padded / gH / gW * b_stride; + int channel_idx = elem % C_padded; + int base_idx = batch_idx + channel_idx; - T nw = (ix_se - ix) * (iy_se - iy); - T ne = (ix - ix_sw) * (iy_sw - iy); - T sw = (ix_ne - ix) * (iy - iy_ne); - T se = (ix - ix_nw) * (iy - iy_nw); + T gix = T(0); + T giy = T(0); + if (channel_idx < C) { + int cot_index = elem / C_padded * C + channel_idx; + T cot = cotangent[cot_index]; + if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 
0 && ix_nw <= W - 1) { + int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed); - int batch_idx = elem / C_padded / gH / gW * b_stride; - int channel_idx = elem % C_padded; - int base_idx = batch_idx + channel_idx; + T I_nw = x[offset]; + gix -= I_nw * (iy_se - iy) * cot; + giy -= I_nw * (ix_se - ix) * cot; + } + if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { + int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); - T gix = T(0); - T giy = T(0); - if (channel_idx < C) { - int cot_index = elem / C_padded * C + channel_idx; - T cot = cotangent[cot_index]; - if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) { - int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed); + T I_ne = x[offset]; + gix += I_ne * (iy_sw - iy) * cot; + giy -= I_ne * (ix - ix_sw) * cot; + } + if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { + int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed); - T I_nw = x[offset]; - gix -= I_nw * (iy_se - iy) * cot; - giy -= I_nw * (ix_se - ix) * cot; - } - if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { - int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); + T I_sw = x[offset]; + gix -= I_sw * (iy - iy_ne) * cot; + giy += I_sw * (ix_ne - ix) * cot; + } + if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) { + int offset = base_idx + iy_se * h_stride + ix_se * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed); - T I_ne = x[offset]; - gix += I_ne * (iy_sw - iy) * cot; - giy -= I_ne * (ix - ix_sw) * cot; - } - if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { - int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed); + T I_se = x[offset]; + gix += I_se * (iy - iy_nw) * cot; + giy += I_se * (ix - ix_nw) * cot; + } + } - T I_sw = x[offset]; - gix -= I_sw * (iy - iy_ne) * cot; - giy += I_sw * (ix_ne - ix) * cot; - } - if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) { - int offset = base_idx + iy_se * h_stride + ix_se * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed); + T gix_mult = W / 2; + T giy_mult = H / 2; - T I_se = x[offset]; - gix += I_se * (iy - iy_nw) * cot; - giy += I_se * (ix - ix_nw) * cot; - } - } + // Reduce across each simdgroup first. + // This is much faster than relying purely on atomics. + gix = simd_sum(gix); + giy = simd_sum(giy); - T gix_mult = W / 2; - T giy_mult = H / 2; + if (thread_index_in_simdgroup == 0) { + atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed); + atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); + } + """ + kernel = mx.fast.metal_kernel( + name="grid_sample_grad", + input_names=["x", "grid", "cotangent"], + output_names=["x_grad", "grid_grad"], + source=source, + atomic_outputs=True, + ) - // Reduce across each simdgroup first. - // This is much faster than relying purely on atomics. 
- gix = simd_sum(gix); - giy = simd_sum(giy); + @grid_sample.vjp + def grid_sample_vjp(primals, cotangent, _): + x, grid = primals + B, _, _, C = x.shape + _, gN, gM, D = grid.shape - if (thread_index_in_simdgroup == 0) { - atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed); - atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); - } - """ - kernel = mx.fast.metal_kernel( - name="grid_sample_grad", - input_names=["x", "grid", "cotangent"], - output_names=["x_grad", "grid_grad"], - source=source, - atomic_outputs=True, - ) - # pad the output channels to simd group size - # so that our `simd_sum`s don't overlap. - simdgroup_size = 32 - C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size - grid_size = B * gN * gM * C_padded - outputs = kernel( - inputs=[x, grid, cotangent], - template=[("T", x.dtype)], - output_shapes=[x.shape, grid.shape], - output_dtypes=[x.dtype, x.dtype], - grid=(grid_size, 1, 1), - threadgroup=(256, 1, 1), - init_value=0, - ) - return outputs[0], outputs[1] + assert D == 2, "Last dim of `grid` must be size 2." + + # pad the output channels to simd group size + # so that our `simd_sum`s don't overlap. + simdgroup_size = 32 + C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size + grid_size = B * gN * gM * C_padded + outputs = kernel( + inputs=[x, grid, cotangent], + template=[("T", x.dtype)], + output_shapes=[x.shape, grid.shape], + output_dtypes=[x.dtype, x.dtype], + grid=(grid_size, 1, 1), + threadgroup=(256, 1, 1), + init_value=0, + ) + return outputs[0], outputs[1] There's an even larger speed up for the vjp: diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 2aef28f99..03f1c2163 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -397,11 +397,11 @@ below. std::ostringstream kname; kname << "axpby_" << "general_" << type_to_name(out); - // Make sure the metal library is available - d.register_library("mlx_ext"); + // Load the metal library + auto lib = d.get_library("mlx_ext"); // Make a kernel from this metal library - auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + auto kernel = d.get_kernel(kname.str(), lib); // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index 291246617..9ba933483 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -172,11 +172,11 @@ void Axpby::eval_gpu( kname << (contiguous_kernel ? 
"contiguous_" : "general_"); kname << type_to_name(out); - // Make sure the metal library is available - d.register_library("mlx_ext"); + // Load the metal library + auto lib = d.get_library("mlx_ext"); // Make a kernel from this metal library - auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + auto kernel = d.get_kernel(kname.str(), lib); // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 6b4b70d47..593b79384 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -677,7 +677,7 @@ void depthwise_conv_2D_gpu( std::string hash_name = kname.str(); auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(base_name, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index ea4f258cc..161503a0e 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -1,12 +1,326 @@ // Copyright © 2024 Apple Inc. +#include +#include + +#include "mlx/backend/common/compiled.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/utils.h" +#include "mlx/fast.h" #include "mlx/fast_primitives.h" +#include "mlx/utils.h" namespace mlx::core::fast { +struct CustomKernelCache { + std::unordered_map libraries; +}; + +static CustomKernelCache& cache() { + static CustomKernelCache cache_; + return cache_; +}; + +std::string write_signature( + std::string func_name, + const std::string& header, + const std::string& source, + const std::vector& input_names, + const std::vector& inputs, + const std::vector& output_names, + const std::vector& output_dtypes, + const std::vector>& template_args, + const std::vector& attributes, + const std::vector& shape_infos, + bool atomic_outputs) { + std::string kernel_source; + kernel_source.reserve(header.size() + source.size() + 16384); + kernel_source += header; + // Auto-generate a function signature based on `template_args` + // and the dtype/shape of the arrays passed as `inputs`. + if (!template_args.empty()) { + kernel_source += "template <"; + int i = 0; + for (const auto& [name, arg] : template_args) { + std::string param_type; + if (std::holds_alternative(arg)) { + param_type = "int"; + } else if (std::holds_alternative(arg)) { + param_type = "bool"; + } else if (std::holds_alternative(arg)) { + param_type = "typename"; + } + if (i > 0) { + kernel_source += ", "; + } + kernel_source += param_type; + kernel_source += " "; + kernel_source += name; + i++; + } + kernel_source += ">\n"; + } + kernel_source += "[[kernel]] void "; + kernel_source += func_name; + kernel_source += "(\n"; + + int index = 0; + constexpr int max_constant_array_size = 8; + // Add inputs + for (int i = 0; i < inputs.size(); ++i) { + const auto& name = input_names[i]; + const auto& arr = inputs[i]; + auto dtype = get_type_string(arr.dtype()); + std::string location = + arr.size() < max_constant_array_size ? "constant" : "device"; + std::string ref = arr.ndim() == 0 ? 
"&" : "*"; + kernel_source += " const "; + kernel_source += location; + kernel_source += " "; + kernel_source += dtype; + kernel_source += ref; + kernel_source += " "; + kernel_source += name; + kernel_source += " [[buffer("; + kernel_source += std::to_string(index); + kernel_source += ")]],\n"; + index++; + // Add input shape, strides and ndim if present in the source + if (arr.ndim() > 0) { + if (shape_infos[i].shape) { + kernel_source += + (" const constant int* " + name + "_shape [[buffer(" + + std::to_string(index) + ")]],\n"); + index++; + } + if (shape_infos[i].strides) { + kernel_source += + (" const constant int64_t* " + name + "_strides [[buffer(" + + std::to_string(index) + ")]],\n"); + index++; + } + if (shape_infos[i].ndim) { + kernel_source += + (" const constant int& " + name + "_ndim [[buffer(" + + std::to_string(index) + ")]],\n"); + index++; + } + } + } + // Add outputs + for (int i = 0; i < output_names.size(); ++i) { + const auto& name = output_names[i]; + const auto& dtype = output_dtypes[i]; + kernel_source += " device "; + auto type_string = get_type_string(dtype); + if (atomic_outputs) { + kernel_source += "atomic<"; + } + kernel_source += type_string; + if (atomic_outputs) { + kernel_source += ">"; + } + kernel_source += "* "; + kernel_source += name; + kernel_source += " [[buffer("; + kernel_source += std::to_string(index); + kernel_source += ")]]"; + if (index < inputs.size() + output_names.size() - 1 || + attributes.size() > 0) { + kernel_source += ",\n"; + } else { + kernel_source += ") {\n"; + } + index++; + } + + index = 0; + for (const auto& attr : attributes) { + kernel_source += attr; + if (index < attributes.size() - 1) { + kernel_source += ",\n"; + } else { + kernel_source += ") {\n"; + } + index++; + } + kernel_source += source; + kernel_source += "\n}\n"; + return kernel_source; +} + +std::string write_template( + const std::vector>& template_args) { + std::ostringstream template_def; + template_def << "<"; + int i = 0; + for (const auto& [name, arg] : template_args) { + if (i > 0) { + template_def << ", "; + } + if (std::holds_alternative(arg)) { + template_def << std::get(arg); + } else if (std::holds_alternative(arg)) { + template_def << std::get(arg); + } else if (std::holds_alternative(arg)) { + template_def << get_type_string(std::get(arg)); + } + i++; + } + template_def << ">"; + return template_def.str(); +} + +MetalKernelFunction metal_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header /* = "" */, + bool ensure_row_contiguous /* = true */, + bool atomic_outputs /* = false */) { + if (output_names.empty()) { + throw std::invalid_argument( + "[metal_kernel] Must specify at least one output."); + } + std::vector shape_infos; + for (auto& n : input_names) { + CustomKernelShapeInfo shape_info; + shape_info.shape = source.find(n + "_shape") != std::string::npos; + shape_info.strides = source.find(n + "_strides") != std::string::npos; + shape_info.ndim = source.find(n + "_ndim") != std::string::npos; + shape_infos.push_back(shape_info); + } + const std::vector> metal_attributes = { + {"dispatch_quadgroups_per_threadgroup", "uint"}, + {"dispatch_simdgroups_per_threadgroup", "uint"}, + {"dispatch_threads_per_threadgroup", "uint3"}, + {"grid_origin", "uint3"}, + {"grid_size", "uint3"}, + {"quadgroup_index_in_threadgroup", "uint"}, + {"quadgroups_per_threadgroup", "uint"}, + {"simdgroup_index_in_threadgroup", "uint"}, + 
{"simdgroups_per_threadgroup", "uint"}, + {"thread_execution_width", "uint"}, + {"thread_index_in_quadgroup", "uint"}, + {"thread_index_in_simdgroup", "uint"}, + {"thread_index_in_threadgroup", "uint"}, + {"thread_position_in_grid", "uint3"}, + {"thread_position_in_threadgroup", "uint3"}, + {"threadgroup_position_in_grid", "uint3"}, + {"threadgroups_per_grid", "uint3"}, + {"threads_per_grid", "uint3"}, + {"threads_per_simdgroup", "uint"}, + {"threads_per_threadgroup", "uint3"}, + }; + + std::vector attributes; + for (const auto& [attr, dtype] : metal_attributes) { + if (source.find(attr) != std::string::npos) { + attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]"); + } + } + + return [=, + shape_infos = std::move(shape_infos), + attributes = std::move(attributes)]( + const std::vector& inputs, + const std::vector& output_shapes, + const std::vector& output_dtypes, + std::tuple grid, + std::tuple threadgroup, + const std::vector>& + template_args = {}, + std::optional init_value = std::nullopt, + bool verbose = false, + StreamOrDevice s_ = {}) { + if (inputs.size() != input_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `inputs` to have size " + << input_names.size() << " but got size " << inputs.size() << "." + << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_shapes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `output_shapes` to have size " + << output_names.size() << " but got size " << output_shapes.size() + << "." << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_dtypes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `output_dtypes` to have size " + << output_names.size() << " but got size " << output_dtypes.size() + << "." 
<< std::endl; + throw std::invalid_argument(msg.str()); + } + + auto s = to_stream(s_); + if (s.device != Device::gpu) { + throw std::invalid_argument("[metal_kernel] Only supports the GPU."); + } + + std::string kernel_name = "custom_kernel_" + name; + std::string template_def = ""; + if (!template_args.empty()) { + std::regex disallowed_chars("\\<|\\>|(, )"); + template_def = write_template(template_args); + auto template_hash = + std::regex_replace(template_def, disallowed_chars, "_"); + template_hash.pop_back(); + kernel_name += "_"; + kernel_name += template_hash; + } + + std::string kernel_source = write_signature( + kernel_name, + header, + source, + input_names, + inputs, + output_names, + output_dtypes, + template_args, + attributes, + shape_infos, + atomic_outputs); + + if (!template_args.empty()) { + template_def = kernel_name + template_def; + kernel_source += "\ntemplate [[host_name(\""; + kernel_source += kernel_name; + kernel_source += "\")]] [[kernel]] decltype("; + kernel_source += template_def; + kernel_source += ") "; + kernel_source += template_def; + kernel_source += ";\n"; + } + + if (verbose) { + std::cout << "Generated source code for `" << name << "`:" << std::endl + << "```" << std::endl + << kernel_source << std::endl + << "```" << std::endl; + } + + return array::make_arrays( + std::move(output_shapes), + std::move(output_dtypes), + std::make_shared<CustomKernel>( + s, + std::move(kernel_name), + std::move(kernel_source), + grid, + threadgroup, + shape_infos, + ensure_row_contiguous, + init_value), + std::move(inputs)); + }; +} + void CustomKernel::eval_gpu( const std::vector<array>& inputs, std::vector<array>& outputs) { @@ -39,9 +353,23 @@ void CustomKernel::eval_gpu( } auto& d = metal::device(s.device); - const auto& lib_name = name_; - auto lib = - d.get_library(lib_name, [this] { return metal::utils() + source_; }); + + { + // Clear kernels from the device library cache if needed + auto& kernel_cache = cache(); + if (auto it = kernel_cache.libraries.find(name_); + it != kernel_cache.libraries.end()) { + if (it->second != source_) { + auto& d = metal::device(s.device); + d.clear_library(name_); + it->second = source_; + } + } else { + kernel_cache.libraries.emplace(name_, source_); + } + } + + auto lib = d.get_library(name_, [this] { return metal::utils() + source_; }); auto kernel = d.get_kernel(name_, lib); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -73,6 +401,16 @@ void CustomKernel::eval_gpu( } const auto [tx, ty, tz] = threadgroup_; + auto tg_size = tx * ty * tz; + auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup(); + if (tg_size > max_tg_size) { + std::ostringstream msg; + msg << "Thread group size (" << tg_size << ") is greater than " + << " the maximum allowed threads per threadgroup (" << max_tg_size + << ")."; + throw std::invalid_argument(msg.str()); + } + const auto [gx, gy, gz] = grid_; MTL::Size group_dims = MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz)); diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index ebc3cc77f..425274361 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -295,7 +295,7 @@ void CommandEncoder::barrier() { Device::Device() { auto pool = new_scoped_memory_pool(); device_ = load_device(); - library_map_ = {{"mlx", load_default_library(device_)}}; + default_library_ = load_default_library(device_); arch_ = std::string(device_->architecture()->name()->utf8String()); auto arch = arch_.back(); switch (arch) { @@
-326,11 +326,11 @@ Device::Device() { Device::~Device() { auto pool = new_scoped_memory_pool(); - for (auto& k : kernel_map_) { - k.second->release(); - } - for (auto& l : library_map_) { - l.second->release(); + for (auto& [l, kernel_map] : library_kernels_) { + l->release(); + for (auto& [_, k] : kernel_map) { + k->release(); + } } stream_map_.clear(); device_->release(); @@ -474,13 +474,24 @@ CommandEncoder& Device::get_command_encoder(int index) { return *stream.encoder; } -void Device::register_library( - const std::string& lib_name, - const std::string& lib_path) { - if (auto it = library_map_.find(lib_name); it == library_map_.end()) { - auto new_lib = load_library(device_, lib_name, lib_path.c_str()); - library_map_.insert({lib_name, new_lib}); +MTL::Library* Device::get_library( + const std::string& name, + const std::string& path /* = "" */) { + { + std::shared_lock rlock(library_mtx_); + if (auto it = library_map_.find(name); it != library_map_.end()) { + return it->second; + } } + + std::unique_lock wlock(library_mtx_); + if (auto it = library_map_.find(name); it != library_map_.end()) { + return it->second; + } + + auto new_lib = load_library(device_, name, path.c_str()); + library_map_.insert({name, new_lib}); + return new_lib; } MTL::Library* Device::build_library_(const std::string& source_string) { @@ -649,6 +660,19 @@ MTL::Library* Device::get_library( return mtl_lib; } +void Device::clear_library(const std::string& name) { + std::unique_lock wlock(library_mtx_); + if (auto it = library_map_.find(name); it != library_map_.end()) { + auto kernel_map_it = library_kernels_.find(it->second); + for (auto& [_, kernel] : kernel_map_it->second) { + kernel->release(); + } + library_kernels_.erase(kernel_map_it); + it->second->release(); + library_map_.erase(it); + } +} + MTL::LinkedFunctions* Device::get_linked_functions_( const std::vector& funcs) { if (funcs.empty()) { @@ -679,6 +703,7 @@ MTL::ComputePipelineState* Device::get_kernel_( std::unique_lock wlock(kernel_mtx_); // Try loading again to avoid loading twice + auto& kernel_map_ = library_kernels_[mtl_lib]; if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) { return it->second; } @@ -713,6 +738,7 @@ MTL::ComputePipelineState* Device::get_kernel( std::shared_lock lock(kernel_mtx_); // Look for cached kernel + auto& kernel_map_ = library_kernels_[mtl_lib]; if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) { return it->second; } @@ -722,23 +748,11 @@ MTL::ComputePipelineState* Device::get_kernel( MTL::ComputePipelineState* Device::get_kernel( const std::string& base_name, - const std::string& lib_name /* = "mlx" */, const std::string& hash_name /* = "" */, const MTLFCList& func_consts /* = {} */, const std::vector& linked_functions /* = {} */) { - const auto& kname = hash_name.size() == 0 ? 
base_name : hash_name; - { - // Multiple readers allowed - std::shared_lock lock(kernel_mtx_); - - // Look for cached kernel - if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) { - return it->second; - } - } - // Search for cached metal lib - MTL::Library* mtl_lib = get_library_(lib_name); - return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions); + return get_kernel( + base_name, default_library_, hash_name, func_consts, linked_functions); } void Device::set_residency_set(const MTL::ResidencySet* residency_set) { diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 660ba65e2..5bfcc6649 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -187,14 +187,16 @@ class Device { CommandEncoder& get_command_encoder(int index); void end_encoding(int index); - void register_library( - const std::string& lib_name, - const std::string& lib_path = ""); + MTL::Library* get_library( + const std::string& name, + const std::string& path = ""); MTL::Library* get_library( const std::string& name, const std::function& builder); + void clear_library(const std::string& name); + MTL::ComputePipelineState* get_kernel( const std::string& base_name, MTL::Library* mtl_lib, @@ -204,7 +206,6 @@ class Device { MTL::ComputePipelineState* get_kernel( const std::string& base_name, - const std::string& lib_name = "mlx", const std::string& hash_name = "", const MTLFCList& func_consts = {}, const std::vector& linked_functions = {}); @@ -258,10 +259,13 @@ class Device { std::unordered_map stream_map_; std::shared_mutex kernel_mtx_; - std::unordered_map kernel_map_; - std::shared_mutex library_mtx_; std::unordered_map library_map_; + MTL::Library* default_library_; + std::unordered_map< + MTL::Library*, + std::unordered_map> + library_kernels_; const MTL::ResidencySet* residency_set_{nullptr}; std::string arch_; int max_ops_per_buffer_; diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 8da147971..b1478d33b 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -146,7 +146,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel( int, int, int) { - return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); + return d.get_kernel(kernel_name, hash_name, func_consts); } MTL::ComputePipelineState* get_steel_gemm_splitk_kernel( @@ -207,7 +207,7 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( int, int, bool) { - return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); + return d.get_kernel(kernel_name, hash_name, func_consts); } MTL::ComputePipelineState* get_gemv_masked_kernel( @@ -259,7 +259,7 @@ MTL::ComputePipelineState* get_fft_kernel( const std::string& hash_name, const metal::MTLFCList& func_consts, const std::string&) { - return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); + return d.get_kernel(kernel_name, hash_name, func_consts); } MTL::ComputePipelineState* get_quantized_kernel( @@ -283,7 +283,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( int, int, bool) { - return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); + return d.get_kernel(kernel_name, hash_name, func_consts); } } // namespace mlx::core diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index c53289828..d570bf3c0 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -172,7 +172,7 @@ void RMSNormVJP::eval_gpu( auto& compute_encoder = d.get_command_encoder(s.index); { - 
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(op_name, hash_name, func_consts); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { @@ -395,7 +395,7 @@ void LayerNormVJP::eval_gpu( }; { - auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(op_name, hash_name, func_consts); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 096d6b906..eef279d1d 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -73,7 +73,7 @@ void sdpa_full_self_attention_metal( std::string hash_name = kname.str(); auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(base_name, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); const int NQ = (qL + bq - 1) / bq; @@ -180,7 +180,7 @@ void sdpa_vector( // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(kname, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); // Set its arguments @@ -281,7 +281,7 @@ void sdpa_vector_2pass( // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(kname, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 409aa2c89..849cbf83e 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -2,6 +2,7 @@ #include "mlx/primitives.h" #include "mlx/distributed/primitives.h" +#include "mlx/fast.h" #include "mlx/fast_primitives.h" #define NO_GPU_MULTI(func) \ @@ -155,6 +156,18 @@ NO_GPU_USE_FALLBACK(RoPE) NO_GPU(ScaledDotProductAttention) NO_GPU_MULTI(AffineQuantize) NO_GPU_MULTI(CustomKernel) + +MetalKernelFunction metal_kernel( + const std::string&, + const std::vector&, + const std::vector&, + const std::string&, + const std::string&, + bool ensure_row_contiguous, + bool atomic_outputs) { + throw std::runtime_error("[metal_kernel] No GPU back-end."); +} + } // namespace fast namespace distributed { diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 657c0aba8..210c7f729 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1,10 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include -#include #include -#include -#include "mlx/backend/common/compiled.h" #include "mlx/fast.h" #include "mlx/fast_primitives.h" #include "mlx/ops.h" @@ -1027,303 +1024,4 @@ std::vector AffineQuantize::output_shapes( } } -std::string write_signature( - std::string func_name, - const std::string& header, - const std::string& source, - const std::vector& input_names, - const std::vector& inputs, - const std::vector& output_names, - const std::vector& output_dtypes, - const std::vector>& template_args, - const std::vector& attributes, - const std::vector& shape_infos, - bool atomic_outputs) { - std::string kernel_source; - kernel_source.reserve(header.size() + source.size() + 16384); - kernel_source += header; - // Auto-generate a function signature based on `template_args` - // and the dtype/shape of the arrays passed as `inputs`. 
- if (!template_args.empty()) { - kernel_source += "template <"; - int i = 0; - for (const auto& [name, arg] : template_args) { - std::string param_type; - if (std::holds_alternative(arg)) { - param_type = "int"; - } else if (std::holds_alternative(arg)) { - param_type = "bool"; - } else if (std::holds_alternative(arg)) { - param_type = "typename"; - } - if (i > 0) { - kernel_source += ", "; - } - kernel_source += param_type; - kernel_source += " "; - kernel_source += name; - i++; - } - kernel_source += ">\n"; - } - kernel_source += "[[kernel]] void "; - kernel_source += func_name; - kernel_source += "(\n"; - - int index = 0; - constexpr int max_constant_array_size = 8; - // Add inputs - for (int i = 0; i < inputs.size(); ++i) { - const auto& name = input_names[i]; - const auto& arr = inputs[i]; - auto dtype = get_type_string(arr.dtype()); - std::string location = - arr.size() < max_constant_array_size ? "constant" : "device"; - std::string ref = arr.ndim() == 0 ? "&" : "*"; - kernel_source += " const "; - kernel_source += location; - kernel_source += " "; - kernel_source += dtype; - kernel_source += ref; - kernel_source += " "; - kernel_source += name; - kernel_source += " [[buffer("; - kernel_source += std::to_string(index); - kernel_source += ")]],\n"; - index++; - // Add input shape, strides and ndim if present in the source - if (arr.ndim() > 0) { - if (shape_infos[i].shape) { - kernel_source += - (" const constant int* " + name + "_shape [[buffer(" + - std::to_string(index) + ")]],\n"); - index++; - } - if (shape_infos[i].strides) { - kernel_source += - (" const constant int64_t* " + name + "_strides [[buffer(" + - std::to_string(index) + ")]],\n"); - index++; - } - if (shape_infos[i].ndim) { - kernel_source += - (" const constant int& " + name + "_ndim [[buffer(" + - std::to_string(index) + ")]],\n"); - index++; - } - } - } - // Add outputs - for (int i = 0; i < output_names.size(); ++i) { - const auto& name = output_names[i]; - const auto& dtype = output_dtypes[i]; - kernel_source += " device "; - auto type_string = get_type_string(dtype); - if (atomic_outputs) { - kernel_source += "atomic<"; - } - kernel_source += type_string; - if (atomic_outputs) { - kernel_source += ">"; - } - kernel_source += "* "; - kernel_source += name; - kernel_source += " [[buffer("; - kernel_source += std::to_string(index); - kernel_source += ")]]"; - if (index < inputs.size() + output_names.size() - 1 || - attributes.size() > 0) { - kernel_source += ",\n"; - } else { - kernel_source += ") {\n"; - } - index++; - } - - index = 0; - for (const auto& attr : attributes) { - kernel_source += attr; - if (index < attributes.size() - 1) { - kernel_source += ",\n"; - } else { - kernel_source += ") {\n"; - } - index++; - } - kernel_source += source; - kernel_source += "\n}\n"; - return kernel_source; -} - -std::string write_template( - const std::vector>& template_args) { - std::ostringstream template_def; - template_def << "<"; - int i = 0; - for (const auto& [name, arg] : template_args) { - if (i > 0) { - template_def << ", "; - } - if (std::holds_alternative(arg)) { - template_def << std::get(arg); - } else if (std::holds_alternative(arg)) { - template_def << std::get(arg); - } else if (std::holds_alternative(arg)) { - template_def << get_type_string(std::get(arg)); - } - i++; - } - template_def << ">"; - return template_def.str(); -} - -MetalKernelFunction metal_kernel( - const std::string& name, - const std::vector& input_names, - const std::vector& output_names, - const std::string& source, - const 
std::string& header /* = "" */, - bool ensure_row_contiguous /* = true */, - bool atomic_outputs /* = false */) { - if (output_names.empty()) { - throw std::invalid_argument( - "[metal_kernel] Must specify at least one output."); - } - std::vector shape_infos; - for (auto& n : input_names) { - CustomKernelShapeInfo shape_info; - shape_info.shape = source.find(n + "_shape") != std::string::npos; - shape_info.strides = source.find(n + "_strides") != std::string::npos; - shape_info.ndim = source.find(n + "_ndim") != std::string::npos; - shape_infos.push_back(shape_info); - } - const std::vector> metal_attributes = { - {"dispatch_quadgroups_per_threadgroup", "uint"}, - {"dispatch_simdgroups_per_threadgroup", "uint"}, - {"dispatch_threads_per_threadgroup", "uint3"}, - {"grid_origin", "uint3"}, - {"grid_size", "uint3"}, - {"quadgroup_index_in_threadgroup", "uint"}, - {"quadgroups_per_threadgroup", "uint"}, - {"simdgroup_index_in_threadgroup", "uint"}, - {"simdgroups_per_threadgroup", "uint"}, - {"thread_execution_width", "uint"}, - {"thread_index_in_quadgroup", "uint"}, - {"thread_index_in_simdgroup", "uint"}, - {"thread_index_in_threadgroup", "uint"}, - {"thread_position_in_grid", "uint3"}, - {"thread_position_in_threadgroup", "uint3"}, - {"threadgroup_position_in_grid", "uint3"}, - {"threadgroups_per_grid", "uint3"}, - {"threads_per_grid", "uint3"}, - {"threads_per_simdgroup", "uint"}, - {"threads_per_threadgroup", "uint3"}, - }; - - std::vector attributes; - for (const auto& [attr, dtype] : metal_attributes) { - if (source.find(attr) != std::string::npos) { - attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]"); - } - } - - return [=, - shape_infos = std::move(shape_infos), - attributes = std::move(attributes)]( - const std::vector& inputs, - const std::vector& output_shapes, - const std::vector& output_dtypes, - std::tuple grid, - std::tuple threadgroup, - const std::vector>& - template_args = {}, - std::optional init_value = std::nullopt, - bool verbose = false, - StreamOrDevice s_ = {}) { - if (inputs.size() != input_names.size()) { - std::ostringstream msg; - msg << "[metal_kernel] Expected `inputs` to have size " - << input_names.size() << " but got size " << inputs.size() << "." - << std::endl; - throw std::invalid_argument(msg.str()); - } - if (output_shapes.size() != output_names.size()) { - std::ostringstream msg; - msg << "[metal_kernel] Expected `output_shapes` to have size " - << output_names.size() << " but got size " << output_shapes.size() - << "." << std::endl; - throw std::invalid_argument(msg.str()); - } - if (output_dtypes.size() != output_names.size()) { - std::ostringstream msg; - msg << "[metal_kernel] Expected `output_dtypes` to have size " - << output_names.size() << " but got size " << output_dtypes.size() - << "." 
<< std::endl; - throw std::invalid_argument(msg.str()); - } - - auto s = to_stream(s_); - if (s.device != Device::gpu) { - throw std::invalid_argument("[metal_kernel] Only supports the GPU."); - } - - std::ostringstream func_name; - std::string template_def = ""; - std::string hash_key = ""; - if (!template_args.empty()) { - std::regex disallowed_chars("\\<|\\>|(, )"); - template_def = write_template(template_args); - hash_key = std::regex_replace(template_def, disallowed_chars, "_"); - hash_key.pop_back(); - } - func_name << "custom_kernel_" << name << hash_key; - std::string kernel_name = func_name.str(); - - std::string kernel_source = write_signature( - kernel_name, - header, - source, - input_names, - inputs, - output_names, - output_dtypes, - template_args, - attributes, - shape_infos, - atomic_outputs); - - if (!template_args.empty()) { - template_def = kernel_name + template_def; - kernel_source += "\ntemplate [[host_name(\""; - kernel_source += kernel_name; - kernel_source += "\")]] [[kernel]] decltype("; - kernel_source += template_def; - kernel_source += ") "; - kernel_source += template_def; - kernel_source += ";\n"; - } - - if (verbose) { - std::cout << "Generated source code for `" << name << "`:" << std::endl - << "```" << std::endl - << kernel_source << std::endl - << "```" << std::endl; - } - - return array::make_arrays( - std::move(output_shapes), - std::move(output_dtypes), - std::make_shared( - s, - std::move(kernel_name), - std::move(kernel_source), - grid, - threadgroup, - shape_infos, - ensure_row_contiguous, - init_value), - std::move(inputs)); - }; -} - } // namespace mlx::core::fast diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 2c90a3755..59c2fc3ef 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -735,6 +735,41 @@ class TestFast(mlx_tests.MLXTestCase): )[0] self.assertEqual(out.item(), 2) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_custom_kernel_caching(self): + def call_kernel(a: mx.array, source): + kernel = mx.fast.metal_kernel( + name="my_kernel", + input_names=["inp"], + output_names=["out"], + source=source, + ) + return kernel( + inputs=[a], + grid=(a.size, 1, 1), + threadgroup=(a.size, 1, 1), + output_shapes=[a.shape], + output_dtypes=[a.dtype], + stream=mx.gpu, + )[0] + + a = mx.random.normal(shape=(32,)) + + source = """ + uint elem = thread_position_in_grid.x; + out[elem] = 0.0; + """ + + out = call_kernel(a, source) + self.assertTrue(mx.array_equal(out, mx.zeros_like(out))) + + source = """ + uint elem = thread_position_in_grid.x; + out[elem] = 1.0; + """ + out = call_kernel(a, source) + self.assertTrue(mx.array_equal(out, mx.ones_like(out))) + if __name__ == "__main__": unittest.main()
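The usage pattern the updated documentation recommends — build the kernel once with ``mx.fast.metal_kernel`` and reuse it, so its Metal library is compiled and cached a single time — is sketched below. This example is illustrative only and not part of the patch: the kernel name and body are made up, but the call signature follows the documentation above.

.. code-block:: python

    import mlx.core as mx

    # Construct the kernel object once (one Metal library build) and reuse it.
    source = """
        uint elem = thread_position_in_grid.x;
        out[elem] = inp[elem] + T(1);
    """
    plus_one_kernel = mx.fast.metal_kernel(
        name="plus_one",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )

    def plus_one(a: mx.array) -> mx.array:
        # Each call reuses the cached kernel; only the launch parameters change.
        return plus_one_kernel(
            inputs=[a],
            template=[("T", a.dtype)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
        )[0]

    b = plus_one(mx.arange(4.0))  # array([1, 2, 3, 4], dtype=float32)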