diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst index 12204d24a..04cede771 100644 --- a/docs/src/dev/custom_metal_kernels.rst +++ b/docs/src/dev/custom_metal_kernels.rst @@ -43,7 +43,8 @@ The full function signature will be generated using: * The keys and shapes/dtypes of ``inputs`` In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp`` so we will add ``const device float16_t* inp`` to the signature. - ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience. + ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present + in ``source``. * The keys and values of ``output_shapes`` and ``output_dtypes`` In the above, ``out`` is an ``mx.array`` of type ``mx.float16`` so we add ``device float16_t* out``. @@ -73,7 +74,7 @@ Putting this all together, the generated function signature for ``myexp`` is as template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float) custom_kernel_myexp_float; -You can print the generated code for a ``mx.fast.metal_kernel`` by passing ``verbose=True`` when you call it. +Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes. Using Shape/Strides ------------------- @@ -121,3 +122,292 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely a = a[::2] b = exp_elementwise(a) assert mx.allclose(b, mx.exp(a)) + +Complex Example +----------------------------- + +Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode. + +We'll start with the following MLX implementation using standard ops: + +.. code-block:: python + + def grid_sample_ref(x, grid): + N, H_in, W_in, _ = x.shape + ix = ((grid[..., 0] + 1) * W_in - 1) / 2 + iy = ((grid[..., 1] + 1) * H_in - 1) / 2 + + ix_nw = mx.floor(ix).astype(mx.int32) + iy_nw = mx.floor(iy).astype(mx.int32) + + ix_ne = ix_nw + 1 + iy_ne = iy_nw + + ix_sw = ix_nw + iy_sw = iy_nw + 1 + + ix_se = ix_nw + 1 + iy_se = iy_nw + 1 + + nw = (ix_se - ix) * (iy_se - iy) + ne = (ix - ix_sw) * (iy_sw - iy) + sw = (ix_ne - ix) * (iy - iy_ne) + se = (ix - ix_nw) * (iy - iy_nw) + + I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] + I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] + I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] + I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] + + mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) + mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) + mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) + mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) + + I_nw *= mask_nw[..., None] + I_ne *= mask_ne[..., None] + I_sw *= mask_sw[..., None] + I_se *= mask_se[..., None] + + output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se + + return output + +Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel`` +to write a fast GPU kernel for both the forward and backward passes. + +First we'll implement the forward pass as a fused kernel: + +.. code-block:: python + + @mx.custom_function + def grid_sample(x, grid): + + assert x.ndim == 4, "`x` must be 4D." + assert grid.ndim == 4, "`grid` must be 4D." 
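+        # Assumed layout, consistent with the reference implementation above:
+        # `x` is NHWC and `grid` holds normalized (x, y) coordinates in [-1, 1].
+        # The kernel below launches one thread per output element.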
+ + B, _, _, C = x.shape + _, gN, gM, D = grid.shape + out_shape = (B, gN, gM, C) + + assert D == 2, "Last dim of `grid` must be size 2." + + source = """ + uint elem = thread_position_in_grid.x; + int H = x_shape[1]; + int W = x_shape[2]; + int C = x_shape[3]; + int gH = grid_shape[1]; + int gW = grid_shape[2]; + + int w_stride = C; + int h_stride = W * w_stride; + int b_stride = H * h_stride; + + uint grid_idx = elem / C * 2; + float ix = ((grid[grid_idx] + 1) * W - 1) / 2; + float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; + + int ix_nw = floor(ix); + int iy_nw = floor(iy); + + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; + + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; + + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; + + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); + + int batch_idx = elem / C / gH / gW * b_stride; + int channel_idx = elem % C; + int base_idx = batch_idx + channel_idx; + + T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; + T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; + T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride]; + T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride]; + + I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; + I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; + I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0; + I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0; + + out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; + """ + kernel = mx.fast.metal_kernel( + name="grid_sample", + source=source, + ) + outputs = kernel( + inputs={"x": x, "grid": grid}, + template={"T": x.dtype}, + output_shapes={"out": out_shape}, + output_dtypes={"out": x.dtype}, + grid=(np.prod(out_shape), 1, 1), + threadgroup=(256, 1, 1), + ) + return outputs["out"] + +For a reasonably sized input such as: + +.. code-block:: python + + x.shape = (8, 1024, 1024, 64) + grid.shape = (8, 256, 256, 2) + +On an M1 Max, we see a big performance improvement: + +``55.7ms -> 6.7ms => 8x speed up`` + +Grid Sample VJP +--------------- + +Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define +its custom vjp transform so MLX can differentiate it. + +The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so +requires a few extra ``mx.fast.metal_kernel`` features: + +* ``init_value=0`` + Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel. + +* ``atomic_outputs=True`` + Designate all of the kernel outputs as ``atomic`` in the function signature. + This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups. + See section 6.15 of the `Metal Shading Language Specification `_ for more details. + +We can then implement the backwards pass as follows: + +.. code-block:: python + + @grid_sample.vjp + def grid_sample_vjp(primals, cotangent, _): + x, grid = primals + B, _, _, C = x.shape + _, gN, gM, D = grid.shape + + assert D == 2, "Last dim of `grid` must be size 2." 
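+        # One thread is launched per (batch, grid point, padded channel); the
+        # cotangent has the shape of the forward output, (B, gN, gM, C).
+        # Both gradients are only ever accumulated into with atomic adds, so the
+        # output arrays must start at zero (handled by ``init_value=0`` below).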
+ + source = """ + uint elem = thread_position_in_grid.x; + int H = x_shape[1]; + int W = x_shape[2]; + int C = x_shape[3]; + // Pad C to the nearest larger simdgroup size multiple + int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup; + + int gH = grid_shape[1]; + int gW = grid_shape[2]; + + int w_stride = C; + int h_stride = W * w_stride; + int b_stride = H * h_stride; + + uint grid_idx = elem / C_padded * 2; + float ix = ((grid[grid_idx] + 1) * W - 1) / 2; + float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; + + int ix_nw = floor(ix); + int iy_nw = floor(iy); + + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; + + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; + + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; + + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); + + int batch_idx = elem / C_padded / gH / gW * b_stride; + int channel_idx = elem % C_padded; + int base_idx = batch_idx + channel_idx; + + T gix = T(0); + T giy = T(0); + if (channel_idx < C) { + int cot_index = elem / C_padded * C + channel_idx; + T cot = cotangent[cot_index]; + if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) { + int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed); + + T I_nw = x[offset]; + gix -= I_nw * (iy_se - iy) * cot; + giy -= I_nw * (ix_se - ix) * cot; + } + if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { + int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); + + T I_ne = x[offset]; + gix += I_ne * (iy_sw - iy) * cot; + giy -= I_ne * (ix - ix_sw) * cot; + } + if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { + int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed); + + T I_sw = x[offset]; + gix -= I_sw * (iy - iy_ne) * cot; + giy += I_sw * (ix_ne - ix) * cot; + } + if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) { + int offset = base_idx + iy_se * h_stride + ix_se * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed); + + T I_se = x[offset]; + gix += I_se * (iy - iy_nw) * cot; + giy += I_se * (ix - ix_nw) * cot; + } + } + + T gix_mult = W / 2; + T giy_mult = H / 2; + + // Reduce across each simdgroup first. + // This is much faster than relying purely on atomics. + gix = simd_sum(gix); + giy = simd_sum(giy); + + if (thread_index_in_simdgroup == 0) { + atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed); + atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); + } + """ + kernel = mx.fast.metal_kernel( + name="grid_sample_grad", + source=source, + atomic_outputs=True, + ) + # pad the output channels to simd group size + # so that our `simd_sum`s don't overlap. 
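+        # A simdgroup is 32 threads wide on Apple GPUs. Threads whose padded
+        # channel index is >= C skip the gradient math but still participate in
+        # the `simd_sum` reduction, contributing zero.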
+ simdgroup_size = 32 + C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size + grid_size = B * gN * gM * C_padded + outputs = kernel( + inputs={"x": x, "grid": grid, "cotangent": cotangent}, + template={"T": x.dtype}, + output_shapes={"x_grad": x.shape, "grid_grad": grid.shape}, + output_dtypes={"x_grad": x.dtype, "grid_grad": x.dtype}, + grid=(grid_size, 1, 1), + threadgroup=(256, 1, 1), + init_value=0, + ) + return outputs["x_grad"], outputs["grid_grad"] + +There's an even larger speed up for the vjp: + +``676.4ms -> 16.7ms => 40x speed up`` diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index b7c3f7ebf..92cb55dcd 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -12,12 +12,17 @@ void CustomKernel::eval_gpu( std::vector& outputs) { auto& s = stream(); + std::vector copies; + for (auto& out : outputs) { out.set_data(allocator::malloc_or_wait(out.nbytes())); + if (init_value_) { + array init = array(init_value_.value(), out.dtype()); + copy_gpu(init, out, CopyType::Scalar, s); + copies.push_back(init); + } } - std::vector copies; - auto check_input = [&copies, &s, this](const array& x) -> const array { bool no_copy = x.flags().row_contiguous; if (!ensure_row_contiguous_ || no_copy) { diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 3a7c80f8b..186d2081f 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -949,6 +949,7 @@ void write_signature( std::map& output_dtypes, std::optional> template_args, std::vector& shape_infos, + bool atomic_outputs, std::ostringstream& kernel_source) { // Auto-generate a function signature based on `template_args` // and the dtype/shape of the arrays passed as `inputs`. @@ -1042,8 +1043,14 @@ void write_signature( } // Add outputs for (const auto& [name, dtype] : output_dtypes) { - kernel_source << " device " << get_type_string(dtype) << "* " << name - << " [[buffer(" << index << ")]]"; + kernel_source << " device "; + auto type_string = get_type_string(dtype); + if (atomic_outputs) { + kernel_source << "atomic<" << type_string << ">"; + } else { + kernel_source << type_string; + } + kernel_source << "* " << name << " [[buffer(" << index << ")]]"; if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) { kernel_source << "," << std::endl; } else { @@ -1094,6 +1101,7 @@ std::map MetalKernel::operator()( std::tuple grid, std::tuple threadgroup, std::optional> template_args, + std::optional init_value, bool verbose, StreamOrDevice s_) { validate_output_shapes(output_shapes, output_dtypes); @@ -1129,6 +1137,7 @@ std::map MetalKernel::operator()( output_dtypes, template_args, shape_infos, + atomic_outputs_, kernel_source); if (needs_template) { @@ -1174,7 +1183,8 @@ std::map MetalKernel::operator()( grid, threadgroup, shape_infos, - ensure_row_contiguous_), + ensure_row_contiguous_, + init_value), in_arrs); int i = 0; diff --git a/mlx/fast.h b/mlx/fast.h index 4b6e1e2c1..75ac8759a 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -71,10 +71,12 @@ class MetalKernel { MetalKernel( const std::string& name, const std::string& source, - bool ensure_row_contiguous) + bool ensure_row_contiguous, + bool atomic_outputs) : name_(name), source_(source), - ensure_row_contiguous_(ensure_row_contiguous) {} + ensure_row_contiguous_(ensure_row_contiguous), + atomic_outputs_(atomic_outputs) {} std::map operator()( std::map& inputs, @@ -84,6 +86,7 @@ class MetalKernel { std::tuple threadgroup, std::optional> template_args = std::nullopt, + std::optional init_value = 
std::nullopt, bool verbose = false, StreamOrDevice s = {}); @@ -91,5 +94,6 @@ class MetalKernel { std::string name_; std::string source_; bool ensure_row_contiguous_ = true; + bool atomic_outputs_ = false; }; } // namespace mlx::core::fast diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 0039ff01a..1d01610f3 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -1,5 +1,7 @@ // Copyright © 2024 Apple Inc. +#include + #include "mlx/primitives.h" namespace mlx::core::fast { @@ -257,14 +259,16 @@ class CustomKernel : public Primitive { std::tuple grid, std::tuple threadgroup, std::vector shape_infos, - bool ensure_row_contiguous) + bool ensure_row_contiguous, + std::optional init_value) : Primitive(stream), source_(source), name_(name), grid_(grid), threadgroup_(threadgroup), shape_infos_(shape_infos), - ensure_row_contiguous_(ensure_row_contiguous) {} + ensure_row_contiguous_(ensure_row_contiguous), + init_value_(init_value) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -283,6 +287,7 @@ class CustomKernel : public Primitive { std::tuple threadgroup_; std::vector shape_infos_; bool ensure_row_contiguous_; + std::optional init_value_; }; } // namespace mlx::core::fast diff --git a/python/src/fast.cpp b/python/src/fast.cpp index e9178ac2b..863a65ec1 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -200,10 +200,11 @@ void init_fast(nb::module_& parent_module) { A jit-compiled custom Metal kernel defined from a source string. )pbdoc") .def( - nb::init(), + nb::init(), "name"_a, "source"_a, "ensure_row_contiguous"_a = true, + "atomic_outputs"_a = false, R"pbdoc( Initialize a metal_kernel. @@ -215,6 +216,8 @@ void init_fast(nb::module_& parent_module) { used when the kernel is called. ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous before the kernel runs. Default: ``True``. + atomic_outputs (bool): Whether to use atomic outputs in the function signature + e.g. ``device atomic``. Default: ``False``. Returns: Callable ``metal_kernel``. @@ -256,6 +259,7 @@ void init_fast(nb::module_& parent_module) { std::tuple grid, std::tuple threadgroup, std::optional> template_args_, + std::optional init_value, bool verbose, StreamOrDevice s) { std::map inputs; @@ -289,6 +293,7 @@ void init_fast(nb::module_& parent_module) { grid, threadgroup, template_args, + init_value, verbose, s); }, @@ -299,10 +304,11 @@ void init_fast(nb::module_& parent_module) { "grid"_a, "threadgroup"_a, "template"_a = nb::none(), + "init_value"_a = nb::none(), "verbose"_a = false, "stream"_a = nb::none(), nb::sig( - "def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"), + "def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"), R"pbdoc( Run the kernel. @@ -316,9 +322,11 @@ void init_fast(nb::module_& parent_module) { grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with. 
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use. template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments. - These will be added as template arguments to the kernel definition. + These will be added as template arguments to the kernel definition. Default: ``None``. + init_value (float, optional): Optional value to use to initialize all of the output arrays. + By default, output arrays are uninitialized. Default: ``None``. verbose (bool, optional): Whether to print the full generated source code of the kernel - when it is run. + when it is run. Default: ``False``. stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``. Returns: