From 2419edd5b27e4146d36a9c4ddf94b3e7c9901936 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 18 Nov 2024 19:52:00 -0800 Subject: [PATCH] Faster indexing math in a few kernels (#1589) * wip: faster compiled kernels * faster general unary with uint specialization * index type in compiled, unary, binary, ternary, copy * fix jit * jit fix * specialize gather + scatter * nit in docs --- docs/src/usage/function_transforms.rst | 4 +- mlx/backend/common/compiled_cpu.cpp | 2 +- mlx/backend/metal/binary.cpp | 43 ++-- mlx/backend/metal/compiled.cpp | 244 +++++++++++++-------- mlx/backend/metal/copy.cpp | 74 ++++--- mlx/backend/metal/indexing.cpp | 70 +++--- mlx/backend/metal/jit/indexing.h | 8 +- mlx/backend/metal/jit_kernels.cpp | 159 ++++++++------ mlx/backend/metal/kernels/binary.h | 37 ++-- mlx/backend/metal/kernels/binary.metal | 27 ++- mlx/backend/metal/kernels/binary_two.h | 37 ++-- mlx/backend/metal/kernels/binary_two.metal | 27 ++- mlx/backend/metal/kernels/copy.h | 51 +++-- mlx/backend/metal/kernels/copy.metal | 32 +-- mlx/backend/metal/kernels/gather.h | 24 +- mlx/backend/metal/kernels/indexing.h | 2 +- mlx/backend/metal/kernels/scatter.h | 21 +- mlx/backend/metal/kernels/ternary.h | 40 ++-- mlx/backend/metal/kernels/ternary.metal | 17 +- mlx/backend/metal/kernels/unary.h | 16 +- mlx/backend/metal/kernels/unary.metal | 12 +- mlx/backend/metal/kernels/utils.h | 93 ++++---- mlx/backend/metal/ternary.cpp | 46 ++-- mlx/backend/metal/unary.cpp | 17 +- mlx/backend/metal/utils.h | 11 + 25 files changed, 630 insertions(+), 484 deletions(-) diff --git a/docs/src/usage/function_transforms.rst b/docs/src/usage/function_transforms.rst index 9769fceaa..045c36c93 100644 --- a/docs/src/usage/function_transforms.rst +++ b/docs/src/usage/function_transforms.rst @@ -184,8 +184,8 @@ Let's time these two different versions: print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100)) print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100)) -On an M1 Max the 
naive version takes in total ``0.390`` seconds whereas the -vectorized version takes only ``0.025`` seconds, more than ten times faster. +On an M1 Max the naive version takes in total ``5.639`` seconds whereas the +vectorized version takes only ``0.024`` seconds, more than 200 times faster. Of course, this operation is quite contrived. A better approach is to simply do ``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy. diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 0923398c7..cba6325c8 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -279,7 +279,7 @@ void Compiled::eval_cpu( // Figure out which kernel we are using auto& shape = outputs[0].shape(); - bool contiguous = compiled_check_contiguity(inputs, shape); + auto contiguous = compiled_check_contiguity(inputs, shape); // Handle all broadcasting and collect function input arguments std::vector args; diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index 66a31c922..585183967 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -22,37 +22,37 @@ std::string get_kernel_name( BinaryOpType bopt, const std::string& op, const array& a, - bool use_2d, + bool large, int ndim, int work_per_thread) { - std::ostringstream kname; + std::string kname; switch (bopt) { case BinaryOpType::ScalarScalar: - kname << "ss"; + kname = "ss"; break; case BinaryOpType::ScalarVector: - kname << (use_2d ? "sv2" : "sv"); + kname = (large ? "sv2" : "sv"); break; case BinaryOpType::VectorScalar: - kname << (use_2d ? "vs2" : "vs"); + kname = (large ? "vs2" : "vs"); break; case BinaryOpType::VectorVector: - kname << (use_2d ? "vv2" : "vv"); + kname = (large ? 
"vv2" : "vv"); break; case BinaryOpType::General: - kname << "g"; + kname = "g"; if (ndim <= 3) { - kname << ndim; + kname += std::to_string(ndim); } else { - kname << "n"; - if (work_per_thread > 1) { - kname << work_per_thread; - } + concatenate(kname, "n", std::to_string(work_per_thread)); + } + if (large) { + kname += "large"; } break; } - kname << "_" << op << type_to_name(a); - return kname.str(); + concatenate(kname, "_", op, type_to_name(a)); + return kname; } void binary_op_gpu_inplace( @@ -81,11 +81,16 @@ void binary_op_gpu_inplace( }; auto [shape, strides_a, strides_b, strides_out] = maybe_collapse(); - bool use_2d = out.data_size() > UINT32_MAX; + bool large = out.data_size() > UINT32_MAX; auto ndim = shape.size(); - int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1; + int work_per_thread; + if (bopt == BinaryOpType::General) { + work_per_thread = large ? 4 : 2; + } else { + work_per_thread = 1; + } std::string kernel_name = - get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread); + get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread); auto& d = metal::device(s.device); auto kernel = outputs.size() == 2 @@ -141,8 +146,8 @@ void binary_op_gpu_inplace( thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) + : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index a5a8805fb..ffe5af6e5 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -1,5 +1,6 @@ // Copyright © 2023-2024 Apple Inc. 
- +#include +#include //TODO #include #include "mlx/backend/common/compiled.h" @@ -11,12 +12,12 @@ #include "mlx/primitives.h" #include "mlx/utils.h" +using namespace fmt::literals; + namespace mlx::core { -constexpr int WORK_PER_THREAD = 4; - inline void build_kernel( - std::ostream& os, + std::string& os, const std::string& kernel_name, const std::vector& inputs, const std::vector& outputs, @@ -41,8 +42,8 @@ inline void build_kernel( int cnt = 0; // Start the kernel - os << "[[host_name(\"" << kernel_name << "\")]]\n" - << "[[kernel]] void " << kernel_name << "(\n"; + os += fmt::format( + "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name); // Add the input arguments for (auto& x : inputs) { @@ -54,51 +55,61 @@ inline void build_kernel( } // Scalars and contiguous need no strides - if (is_scalar(x) || contiguous) { - os << " device const " << get_type_string(x.dtype()) << "* " << xname - << " [[buffer(" << cnt++ << ")]],\n"; - } else { + if (!is_scalar(x) && !contiguous) { add_indices = true; - os << " device const " << get_type_string(x.dtype()) << "* " << xname - << " [[buffer(" << cnt++ << ")]],\n"; } + os += fmt::format( + " device const {0}* {1} [[buffer({2})]],\n", + get_type_string(x.dtype()), + xname, + cnt++); } if (add_indices) { - os << " constant const size_t* in_strides [[buffer(" << cnt++ - << ")]],\n"; + os += fmt::format( + " constant const size_t* in_strides [[buffer({0})]],\n", cnt++); } // Add the output arguments for (auto& x : outputs) { - os << " device " << get_type_string(x.dtype()) << "* " - << namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n"; + os += fmt::format( + " device {0}* {1} [[buffer({2})]],\n", + get_type_string(x.dtype()), + namer.get_name(x), + cnt++); } // Add output strides and shape to extract the indices. 
if (!contiguous) { - os << " constant const size_t* output_strides [[buffer(" << cnt++ - << ")]],\n" - << " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n"; + os += fmt::format( + " constant const size_t* output_strides [[buffer({0})]],\n", cnt++); + os += fmt::format( + " constant const int* output_shape [[buffer({0})]],\n", cnt++); } if (dynamic_dims) { - os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n"; + os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++); } // The thread index in the whole grid - os << " uint3 pos [[thread_position_in_grid]],\n" - << " uint3 grid [[threads_per_grid]]) {\n"; + os += " uint3 pos [[thread_position_in_grid]],\n"; + os += " uint3 grid [[threads_per_grid]]) {\n"; - if (use_big_index) { + std::string idx_type = use_big_index ? "size_t" : "uint"; + if (contiguous && use_big_index) { // This is only used for contiguous kernels which don't have // a third grid dimension - os << " size_t index = pos.x + grid.x * size_t(pos.y);\n"; + os += " size_t index = pos.x + grid.x * size_t(pos.y);\n"; } else if (work_per_thread > 1) { - os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n" - << " int xshape = output_shape[" - << (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n" - << " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n"; + os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); + os += fmt::format( + " int xshape = output_shape[{0}];\n", + dynamic_dims ? 
"ndim - 1" : std::to_string(ndim - 1)); + os += fmt::format( + " {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n", + idx_type); } else { - os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n"; + os += fmt::format( + " {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n", + idx_type); } // Read constant / contiguous inputs in tmps @@ -109,16 +120,19 @@ inline void build_kernel( if (is_constant(x)) { auto type_str = get_type_string(x.dtype()); - os << " auto tmp_" << xname << " = static_cast<" - << get_type_string(x.dtype()) << ">("; - print_constant(os, x); - os << ");\n"; + std::ostringstream ss; + print_constant(ss, x); + os += fmt::format( + " auto tmp_{0} = static_cast<{1}>({2});\n", + xname, + get_type_string(x.dtype()), + ss.str()); } else if (is_scalar(x)) { - os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " - << xname << "[0];\n"; + os += fmt::format( + " {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname); } else if (contiguous) { - os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " - << xname << "[index];\n"; + os += fmt::format( + " {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname); } else { nc_inputs.push_back(x); } @@ -127,83 +141,98 @@ inline void build_kernel( // Initialize the indices for non-contiguous inputs for (int i = 0; i < nc_inputs.size(); ++i) { auto& xname = namer.get_name(nc_inputs[i]); + os += fmt::format(" {0} index_{1} = ", idx_type, xname); if (ndim == 1) { int offset = i * ndim; - os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, " - << "in_strides[" << offset << "]);\n"; + os += fmt::format( + "elem_to_loc_1(pos.x, in_strides[{0}]);\n", offset); } else if (ndim == 2) { int offset = i * ndim; - os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, " - << "in_strides + " << offset << ");\n"; + os += fmt::format( + "elem_to_loc_2({{pos.x, pos.y}}, in_strides + {1});\n", + idx_type, + offset); } else 
if (ndim == 3) { int offset = i * ndim; - os << " size_t index_" << xname << " = elem_to_loc_3(pos, " - << "in_strides + " << offset << ");\n"; + os += fmt::format( + "elem_to_loc_3(pos, in_strides + {1});\n", + idx_type, + offset); } else if (!dynamic_dims) { - int offset = i * ndim; - os << " size_t index_" << xname << " = N_ * pos.x * in_strides[" - << offset + ndim - 1 << "]" - << " + pos.y * in_strides[" << offset + ndim - 2 << "];\n"; + int offset = (i + 1) * ndim; + os += fmt::format( + "N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n", + idx_type, + offset - 1, + offset - 2); } else { - os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * " - << i << " + ndim - 1]" - << " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n"; + os += fmt::format( + "N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n", + idx_type, + i); } } + if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) { - os << " uint zpos = pos.z;\n"; + os += " uint zpos = pos.z;\n"; if (dynamic_dims) { - os << " for (int d = ndim - 3; d >= 0; --d) {\n"; + os += " for (int d = ndim - 3; d >= 0; --d) {\n"; } else { - os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n"; + os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3); } - os << " uint l = zpos % output_shape[d];\n"; + os += " uint l = zpos % output_shape[d];\n"; for (int i = 0; i < nc_inputs.size(); ++i) { auto& xname = namer.get_name(nc_inputs[i]); - os << " index_" << xname << " += "; + os += fmt::format(" index_{0} += ", xname); if (dynamic_dims) { - os << "l * in_strides[" << i << " * ndim + d];\n"; + os += + fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i); } else { - os << "l * in_strides[" << i * ndim << " + d];\n"; + os += + fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim); } } - os << " zpos /= output_shape[d];\n }\n"; + os += " zpos /= output_shape[d];\n }\n"; } // Open per-thread 
loop if (work_per_thread > 1) { - os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n"; + os += + " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n"; } // Read non-contiguous inputs into tmps for (int i = 0; i < nc_inputs.size(); ++i) { auto& x = nc_inputs[i]; auto& xname = namer.get_name(x); - os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " - << xname << "[index_" << xname << "];\n"; + os += fmt::format( + " {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname); } // Actually write the computation for (auto& x : tape) { - os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x) - << " = "; + os += fmt::format( + " {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x)); if (is_static_cast(x.primitive())) { - os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" - << namer.get_name(x.inputs()[0]) << ");\n"; + os += fmt::format( + "static_cast<{0}>(tmp_{1});\n", + get_type_string(x.dtype()), + namer.get_name(x.inputs()[0])); } else { - x.primitive().print(os); - os << "()("; + std::ostringstream ss; + x.primitive().print(ss); + os += ss.str(); + os += "()("; for (int i = 0; i < x.inputs().size() - 1; i++) { - os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; + os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i])); } - os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n"; + os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back())); } } // Write the outputs from tmps for (auto& x : outputs) { - os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x) - << ";\n"; + os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x)); } // Increment indices and close per thread loop if (work_per_thread > 1) { @@ -211,18 +240,18 @@ inline void build_kernel( auto& x = nc_inputs[i]; auto& xname = namer.get_name(x); if (!dynamic_dims) { - os << " index_" << xname << " += " - << "in_strides[" << i * ndim + ndim - 1 << "];\n"; 
+ os += fmt::format( + " index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1); } else { - os << " index_" << xname << " += " - << "in_strides[" << i << " * ndim + ndim - 1];\n"; + os += fmt::format( + " index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i); } } - os << " index++;\n }\n"; + os += " index++;\n }\n"; } // Finish the kernel - os << "}\n"; + os += "}\n"; if (cnt > 31) { std::ostringstream msg; @@ -246,9 +275,9 @@ void Compiled::eval_gpu( auto& s = stream(); auto& d = metal::device(s.device); auto lib = d.get_library(kernel_lib_, [&]() { - std::ostringstream kernel; - kernel << metal::utils() << metal::unary_ops() << metal::binary_ops() - << metal::ternary_ops(); + std::string kernel = metal::utils(); + concatenate( + kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops()); build_kernel( kernel, kernel_lib_ + "_contiguous", @@ -261,7 +290,7 @@ void Compiled::eval_gpu( /* dynamic_dims = */ false); build_kernel( kernel, - kernel_lib_ + "_contiguous_big", + kernel_lib_ + "_contiguous_large", inputs_, outputs_, tape_, @@ -282,7 +311,21 @@ void Compiled::eval_gpu( /* ndim = */ i, /* dynamic_dims = */ false, /* use_big_index = */ false, - /* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1); + /* work_per_thread = */ i > 3 ? 2 : 1); + if (i > 1) { + build_kernel( + kernel, + kernel_lib_ + "_strided_" + std::to_string(i) + "_large", + inputs_, + outputs_, + tape_, + constant_ids_, + /* contiguous = */ false, + /* ndim = */ i, + /* dynamic_dims = */ false, + /* use_big_index = */ true, + /* work_per_thread = */ i > 3 ? 
4 : 1); + } } build_kernel( kernel, @@ -295,13 +338,25 @@ void Compiled::eval_gpu( /* ndim = */ 0, /* dynamic_dims = */ true, /* use_big_index = */ false, - /* work_per_thread = */ WORK_PER_THREAD); - return kernel.str(); + /* work_per_thread = */ 2); + build_kernel( + kernel, + kernel_lib_ + "_strided_dynamic_large", + inputs_, + outputs_, + tape_, + constant_ids_, + /* contiguous = */ false, + /* ndim = */ 0, + /* dynamic_dims = */ true, + /* use_big_index = */ true, + /* work_per_thread = */ 4); + return kernel; }); // Figure out which kernel we are using auto& output_shape = outputs[0].shape(); - bool contiguous = compiled_check_contiguity(inputs, output_shape); + auto contiguous = compiled_check_contiguity(inputs, output_shape); // Collapse contiguous dims to route to a faster kernel if possible. Also // handle all broadcasting. @@ -349,13 +404,19 @@ void Compiled::eval_gpu( collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX); } - bool use_2d = false; + bool large; if (contiguous) { size_t max_size = 0; for (auto& in : inputs) { max_size = std::max(max_size, in.data_size()); } - use_2d = (max_size > UINT32_MAX); + large = (max_size > UINT32_MAX); + } else { + size_t max_size = 0; + for (auto& o : outputs) { + max_size = std::max(max_size, o.size()); + } + large = (max_size > UINT32_MAX); } // Get the kernel from the lib @@ -368,8 +429,9 @@ void Compiled::eval_gpu( } else { kernel_name += std::to_string(shape.size()); } - } else if (use_2d) { - kernel_name += "_big"; + } + if (large) { + kernel_name += "_large"; } auto kernel = d.get_kernel(kernel_name, lib); auto& compute_encoder = d.get_command_encoder(s.index); @@ -422,7 +484,7 @@ void Compiled::eval_gpu( MTL::Size group_dims( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); - MTL::Size grid_dims = use_2d + MTL::Size grid_dims = large ? 
get_2d_grid_dims(outputs[0].shape(), outputs[0].strides()) : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); @@ -430,7 +492,7 @@ void Compiled::eval_gpu( size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t rest = outputs[0].size() / (dim0 * dim1); - int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1; + int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1; dim0 = (dim0 + work_per_thread - 1) / work_per_thread; NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); int pow2; diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index f2f31cd1f..f7b2fd865 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -74,40 +74,42 @@ void copy_gpu_inplace( }; auto [shape, strides_in_, strides_out_] = maybe_collapse(); int ndim = shape.size(); - - bool use_2d = out.data_size() > UINT32_MAX; + bool large; + if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { + // Allow for negative strides + large = out.data_size() > INT32_MAX; + } else { + large = out.data_size() > UINT32_MAX; + } auto& d = metal::device(s.device); int work_per_thread = 1; std::string kernel_name; - { - std::ostringstream kname; - switch (ctype) { - case CopyType::Scalar: - kname << (use_2d ? "s2" : "s"); - break; - case CopyType::Vector: - kname << (use_2d ? "v2" : "v"); - break; - case CopyType::General: - kname << "g"; - break; - case CopyType::GeneralGeneral: - kname << "gg"; - break; - } - if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { - if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) { - kname << shape.size(); - } else { - work_per_thread = 4; - kname << "n4"; - } - } - kname << "_copy"; - kname << type_to_name(in) << type_to_name(out); - kernel_name = kname.str(); + switch (ctype) { + case CopyType::Scalar: + kernel_name = (large ? "s2" : "s"); + break; + case CopyType::Vector: + kernel_name = (large ? 
"v2" : "v"); + break; + case CopyType::General: + kernel_name = "g"; + break; + case CopyType::GeneralGeneral: + kernel_name = "gg"; + break; } - + if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { + if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) { + kernel_name += std::to_string(shape.size()); + } else { + work_per_thread = large ? 4 : 2; + concatenate(kernel_name, "n", std::to_string(work_per_thread)); + } + if (large) { + kernel_name += "large"; + } + } + concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); auto kernel = get_copy_kernel(d, kernel_name, in, out); auto& compute_encoder = d.get_command_encoder(s.index); @@ -159,8 +161,8 @@ void copy_gpu_inplace( thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) + : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } } @@ -193,9 +195,9 @@ void fill_gpu(const array& val, array& out, const Stream& s) { return; } out.set_data(allocator::malloc_or_wait(out.nbytes())); - bool use_2d = out.data_size() > UINT32_MAX; + bool large = out.data_size() > UINT32_MAX; auto& d = metal::device(s.device); - std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" + + std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" + type_to_name(val) + type_to_name(out); auto kernel = get_copy_kernel(d, kernel_name, val, out); auto& compute_encoder = d.get_command_encoder(s.index); @@ -210,8 +212,8 @@ void fill_gpu(const array& val, array& out, const Stream& s) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims = large ? 
get_2d_grid_dims(out.shape(), out.strides()) + : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index b7cb0ab7a..bea3e8e57 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -53,27 +53,31 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { int idx_ndim = nidx ? inputs[1].ndim() : 0; size_t ndim = src.ndim(); - std::string lib_name; - std::string kernel_name; + bool large_index = nidx && inputs[1].size() > UINT32_MAX; + bool large_src = src.size() > UINT32_MAX; + bool large_out = out.size() > UINT32_MAX; + bool large = large_index || large_src || large_out; + std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; - { - std::ostringstream kname; - kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx - << "_" << idx_ndim; - lib_name = kname.str(); - kernel_name = lib_name; - } + std::string kernel_name = fmt::format( + "gather{0}{1}_{2}_{3}_{4}", + type_to_name(out), + idx_type_name, + nidx, + idx_ndim, + large ? "size_t" : "uint"); + std::string lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { - std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::gather(); + std::string kernel_source = metal::utils(); + kernel_source += metal::gather(); std::string out_type_str = get_type_string(out.dtype()); std::string idx_type_str = nidx ? get_type_string(inputs[1].dtype()) : "bool"; auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); // Index dimension specializations - kernel_source << fmt::format( + kernel_source += fmt::format( gather_kernels, type_to_name(out) + idx_type_name, out_type_str, @@ -81,8 +85,9 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { nidx, idx_args, idx_arr, - idx_ndim); - return kernel_source.str(); + idx_ndim, + large ? 
"size_t" : "uint"); + return kernel_source; }); auto& compute_encoder = d.get_command_encoder(s.index); @@ -209,8 +214,6 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { nwork = 32; } - std::string lib_name; - std::string kernel_name; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; std::string op_name; switch (reduce_type_) { @@ -231,18 +234,24 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { break; } auto upd_contig = upd.flags().row_contiguous; - { - std::ostringstream kname; - kname << "scatter" << type_to_name(out) << idx_type_name; - kname << "_" << op_name << "_" << nidx << "_" - << (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork; - lib_name = kname.str(); - kernel_name = kname.str(); - } + bool large_out = out.size() > UINT32_MAX; + bool large_idx = nidx && (inputs[1].size() > UINT32_MAX); + bool large_upd = upd.size() > UINT32_MAX; + bool large = large_out || large_idx || large_upd; + std::string kernel_name = fmt::format( + "scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}", + type_to_name(out), + idx_type_name, + op_name, + nidx, + upd_contig ? "updc_true" : "updc_false", + nwork, + large ? 
"size_t" : "uint"); + std::string lib_name = kernel_name; + auto lib = d.get_library(lib_name, [&]() { - std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::reduce_utils() - << metal::scatter(); + std::string kernel_source = metal::utils(); + concatenate(kernel_source, metal::reduce_utils(), metal::scatter()); std::string out_type_str = get_type_string(out.dtype()); std::string idx_type_str = @@ -270,7 +279,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { } auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); - kernel_source << fmt::format( + kernel_source += fmt::format( scatter_kernels, type_to_name(out) + idx_type_name + "_" + op_name, out_type_str, @@ -280,8 +289,9 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { idx_args, idx_arr, upd_contig, - nwork); - return kernel_source.str(); + nwork, + large ? "size_t" : "uint"); + return kernel_source; }); auto& compute_encoder = d.get_command_encoder(s.index); diff --git a/mlx/backend/metal/jit/indexing.h b/mlx/backend/metal/jit/indexing.h index 77e9541a7..eacad0a51 100644 --- a/mlx/backend/metal/jit/indexing.h +++ b/mlx/backend/metal/jit/indexing.h @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. 
constexpr std::string_view gather_kernels = R"( -[[kernel]] void gather{0}_{3}_{6}( +[[kernel]] void gather{0}_{3}_{6}_{7}( const device {1}* src [[buffer(0)]], device {1}* out [[buffer(1)]], const constant int* src_shape [[buffer(2)]], @@ -19,7 +19,7 @@ constexpr std::string_view gather_kernels = R"( Indices<{2}, {3}> idxs{{ {{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}}; - return gather_impl<{1}, {2}, {3}, {6}>( + return gather_impl<{1}, {2}, {3}, {6}, {7}>( src, out, src_shape, @@ -34,7 +34,7 @@ constexpr std::string_view gather_kernels = R"( )"; constexpr std::string_view scatter_kernels = R"( -[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}( +[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}( const device {1}* updates [[buffer(1)]], device mlx_atomic<{1}>* out [[buffer(2)]], const constant int* upd_shape [[buffer(3)]], @@ -54,7 +54,7 @@ constexpr std::string_view scatter_kernels = R"( uint2 gid [[thread_position_in_grid]]) {{ Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}}; - return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>( + return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>( updates, out, upd_shape, diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 83554a6fe..aa43cb0f6 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -46,25 +46,27 @@ MTL::ComputePipelineState* get_unary_kernel( auto lib = d.get_library(lib_name, [&]() { auto in_t = get_type_string(in_type); auto out_t = get_type_string(out_type); - std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::unary_ops() << metal::unary(); - kernel_source << get_template_definition( - "v_" + lib_name, "unary_v", in_t, out_t, op); - kernel_source << get_template_definition( - "v2_" + lib_name, "unary_v2", in_t, out_t, op); - kernel_source << get_template_definition( - "gn4_" + lib_name, "unary_g", in_t, out_t, op, 4); - return kernel_source.str(); + 
std::string kernel_source = metal::utils(); + concatenate(kernel_source, metal::unary_ops(), metal::unary()); + kernel_source += + get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op); + kernel_source += + get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op); + kernel_source += get_template_definition( + "gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint"); + kernel_source += get_template_definition( + "gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4); + return kernel_source; }); return d.get_kernel(kernel_name, lib); } -void add_binary_kernels( +void append_binary_kernels( const std::string lib_name, Dtype in_type, Dtype out_type, const std::string op, - std::ostringstream& kernel_source) { + std::string& kernel_source) { const std::array, 10> kernel_types = {{ {"ss", "binary_ss"}, {"vs", "binary_vs"}, @@ -74,26 +76,24 @@ void add_binary_kernels( {"sv2", "binary_sv2"}, {"vv2", "binary_vv2"}, {"g1", "binary_g_nd1"}, - {"g2", "binary_g_nd2"}, - {"g3", "binary_g_nd3"}, + {"g2large", "binary_g_nd2"}, + {"g3large", "binary_g_nd3"}, }}; + auto in_t = get_type_string(in_type); + auto out_t = get_type_string(out_type); + for (auto& [name, func] : kernel_types) { - std::string template_def; - template_def = get_template_definition( - name + "_" + lib_name, - func, - get_type_string(in_type), - get_type_string(out_type), - op); - kernel_source << template_def; + kernel_source += + get_template_definition(name + "_" + lib_name, func, in_t, out_t, op); } - kernel_source << get_template_definition( - "gn4_" + lib_name, - "binary_g", - get_type_string(in_type), - get_type_string(out_type), - op, - 4); + kernel_source += get_template_definition( + "g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint"); + kernel_source += get_template_definition( + "g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint"); + kernel_source += get_template_definition( + "gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint"); + kernel_source += 
get_template_definition( + "gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4); } MTL::ComputePipelineState* get_binary_kernel( @@ -104,10 +104,11 @@ MTL::ComputePipelineState* get_binary_kernel( const std::string op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { - std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::binary_ops() << metal::binary(); - add_binary_kernels(lib_name, in_type, out_type, op, kernel_source); - return kernel_source.str(); + std::string kernel_source; + kernel_source = metal::utils(); + concatenate(kernel_source, metal::binary_ops(), metal::binary()); + append_binary_kernels(lib_name, in_type, out_type, op, kernel_source); + return kernel_source; }); return d.get_kernel(kernel_name, lib); } @@ -120,11 +121,10 @@ MTL::ComputePipelineState* get_binary_two_kernel( const std::string op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { - std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::binary_ops() - << metal::binary_two(); - add_binary_kernels(lib_name, in_type, out_type, op, kernel_source); - return kernel_source.str(); + std::string kernel_source = metal::utils(); + concatenate(kernel_source, metal::binary_ops(), metal::binary_two()); + append_binary_kernels(lib_name, in_type, out_type, op, kernel_source); + return kernel_source; }); return d.get_kernel(kernel_name, lib); } @@ -136,24 +136,29 @@ MTL::ComputePipelineState* get_ternary_kernel( const std::string op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { - std::ostringstream kernel_source; + auto t_str = get_type_string(type); + std::string kernel_source = metal::utils(); + concatenate(kernel_source, metal::ternary_ops(), metal::ternary()); const std::array, 5> kernel_types = {{ {"v", "ternary_v"}, {"v2", "ternary_v2"}, {"g1", 
"ternary_g_nd1"}, - {"g2", "ternary_g_nd2"}, - {"g3", "ternary_g_nd3"}, + {"g2large", "ternary_g_nd2"}, + {"g3large", "ternary_g_nd3"}, }}; - kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary(); for (auto& [name, func] : kernel_types) { - std::string template_def; - template_def = get_template_definition( - name + "_" + lib_name, func, get_type_string(type), op); - kernel_source << template_def; + kernel_source += + get_template_definition(name + "_" + lib_name, func, t_str, op); } - kernel_source << get_template_definition( - "gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4); - return kernel_source.str(); + kernel_source += get_template_definition( + "g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint"); + kernel_source += get_template_definition( + "g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint"); + kernel_source += get_template_definition( + "gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint"); + kernel_source += get_template_definition( + "gn4large_" + lib_name, "ternary_g", t_str, op, 4); + return kernel_source; }); return d.get_kernel(kernel_name, lib); } @@ -165,31 +170,43 @@ MTL::ComputePipelineState* get_copy_kernel( const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { - std::ostringstream kernel_source; + std::string kernel_source = metal::utils(); + kernel_source += metal::copy(); auto in_type = get_type_string(in.dtype()); auto out_type = get_type_string(out.dtype()); - kernel_source << metal::utils() << metal::copy() - << get_template_definition( - "s_" + lib_name, "copy_s", in_type, out_type) - << get_template_definition( - "v_" + lib_name, "copy_v", in_type, out_type) - << get_template_definition( - "g1_" + lib_name, "copy_g_nd1", in_type, out_type) - << get_template_definition( - "g2_" + lib_name, "copy_g_nd2", in_type, out_type) - << get_template_definition( - "g3_" + lib_name, "copy_g_nd3", in_type, out_type) - << 
get_template_definition( - "gn4_" + lib_name, "copy_g", in_type, out_type, 4) - << get_template_definition( - "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type) - << get_template_definition( - "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type) - << get_template_definition( - "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type) - << get_template_definition( - "ggn4_" + lib_name, "copy_gg", in_type, out_type, 4); - return kernel_source.str(); + kernel_source += + get_template_definition("s_" + lib_name, "copy_s", in_type, out_type); + kernel_source += + get_template_definition("v_" + lib_name, "copy_v", in_type, out_type); + kernel_source += get_template_definition( + "g1_" + lib_name, "copy_g_nd1", in_type, out_type); + kernel_source += get_template_definition( + "g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int"); + kernel_source += get_template_definition( + "g3_" + lib_name, "copy_g_nd3", in_type, out_type, "int"); + kernel_source += get_template_definition( + "gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int"); + kernel_source += get_template_definition( + "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type); + kernel_source += get_template_definition( + "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int"); + kernel_source += get_template_definition( + "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int"); + kernel_source += get_template_definition( + "ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int"); + kernel_source += get_template_definition( + "g2large_" + lib_name, "copy_g_nd2", in_type, out_type); + kernel_source += get_template_definition( + "g3large_" + lib_name, "copy_g_nd3", in_type, out_type); + kernel_source += get_template_definition( + "gn4large_" + lib_name, "copy_g", in_type, out_type, 4); + kernel_source += get_template_definition( + "gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type); + kernel_source += get_template_definition( + "gg3large_" + lib_name, "copy_gg_nd3", in_type, out_type); + 
kernel_source += get_template_definition( + "ggn4large_" + lib_name, "copy_gg", in_type, out_type, 4); + return kernel_source; }); return d.get_kernel(kernel_name, lib); } diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index d64488e9f..4b260bc30 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -77,12 +77,12 @@ template constant const size_t& a_stride, constant const size_t& b_stride, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); c[index] = Op()(a[a_idx], b[b_idx]); } -template +template [[kernel]] void binary_g_nd2( device const T* a, device const T* b, @@ -91,13 +91,13 @@ template constant const size_t b_strides[2], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - size_t out_idx = index.x + size_t(grid_dim.x) * index.y; + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; c[out_idx] = Op()(a[a_idx], b[b_idx]); } -template +template [[kernel]] void binary_g_nd3( device const T* a, device const T* b, @@ -106,14 +106,18 @@ template constant const size_t b_strides[3], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - size_t out_idx = - index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); c[out_idx] = Op()(a[a_idx], b[b_idx]); } -template +template < + typename T, + typename 
U, + typename Op, + int N = 1, + typename IdxT = size_t> [[kernel]] void binary_g( device const T* a, device const T* b, @@ -124,13 +128,12 @@ template constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd( + auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); auto xshape = shape[ndim - 1]; - size_t out_idx = - N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); - auto a_xstride = a_strides[ndim - 1]; - auto b_xstride = b_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { c[out_idx++] = Op()(a[idx.x], b[idx.y]); idx.x += a_xstride; diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index db5de5b59..a9d7044d8 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -9,18 +9,21 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary.h" -#define instantiate_binary_all(op, tname, itype, otype) \ - instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ - instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ - instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ - instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ - instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \ - instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ - instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \ - 
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \ +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ + instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ + instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ + instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \ + instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \ + instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ + instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \ + instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ + instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \ + instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) #define instantiate_binary_integer(op) \ instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h index a4a3130bf..6057dd41b 100644 --- a/mlx/backend/metal/kernels/binary_two.h +++ b/mlx/backend/metal/kernels/binary_two.h @@ -99,14 +99,14 @@ template constant const size_t& a_stride, constant const size_t& b_stride, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); auto out = Op()(a[a_idx], b[b_idx]); c[index] = out[0]; d[index] = out[1]; } -template +template [[kernel]] void binary_g_nd2( device const T* a, device const T* b, @@ -116,15 +116,15 @@ 
template constant const size_t b_strides[2], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - size_t out_idx = index.x + size_t(grid_dim.x) * index.y; + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; auto out = Op()(a[a_idx], b[b_idx]); c[out_idx] = out[0]; d[out_idx] = out[1]; } -template +template [[kernel]] void binary_g_nd3( device const T* a, device const T* b, @@ -134,16 +134,20 @@ template constant const size_t b_strides[3], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - size_t out_idx = - index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); auto out = Op()(a[a_idx], b[b_idx]); c[out_idx] = out[0]; d[out_idx] = out[1]; } -template +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = size_t> [[kernel]] void binary_g( device const T* a, device const T* b, @@ -155,13 +159,12 @@ template constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd( + auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); auto xshape = shape[ndim - 1]; - size_t out_idx = - N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); - auto a_xstride = a_strides[ndim - 1]; - auto b_xstride = b_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; for (int i = 0; i < 
N && (int(N * index.x) + i) < xshape; ++i) { auto out = Op()(a[idx.x], b[idx.y]); c[out_idx] = out[0]; diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index 6f5e48c0e..da9ac3a5d 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -7,18 +7,21 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_two.h" -#define instantiate_binary_all(op, tname, itype, otype) \ - instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ - instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ - instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ - instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ - instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \ - instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ - instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \ - instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \ +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ + instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ + instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ + instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \ + instantiate_kernel("gn4large_" #op #tname, binary_g, itype, 
otype, op, 4) \ + instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ + instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \ + instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \ + instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ + instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) #define instantiate_binary_float(op) \ instantiate_binary_all(op, float16, half, half) \ diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h index 914aebfd6..2113c825a 100644 --- a/mlx/backend/metal/kernels/copy.h +++ b/mlx/backend/metal/kernels/copy.h @@ -42,36 +42,36 @@ template device U* dst [[buffer(1)]], constant const int64_t& src_stride [[buffer(3)]], uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); + auto src_idx = elem_to_loc_1(index, src_stride); dst[index] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_g_nd2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_2(index, src_strides); - int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y; + auto src_idx = elem_to_loc_2(index, src_strides); + IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y; dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_g_nd3( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_3(index, src_strides); - int64_t dst_idx = - index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); + auto src_idx = elem_to_loc_3(index, src_strides); + IdxT dst_idx = + index.x + IdxT(grid_dim.x) * (index.y + 
IdxT(grid_dim.y) * index.z); dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_g( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], @@ -80,17 +80,16 @@ template constant const int& ndim [[buffer(5)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc( + auto src_idx = elem_to_loc( {N * index.x, index.y, index.z}, src_shape, src_strides, ndim); if (N == 1) { - int64_t dst_idx = - index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z); + IdxT dst_idx = + index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); dst[dst_idx] = static_cast(src[src_idx]); return; } auto xshape = src_shape[ndim - 1]; - int64_t dst_idx = - N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z); + IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); auto src_xstride = src_strides[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { dst[dst_idx + i] = static_cast(src[src_idx]); @@ -105,36 +104,36 @@ template constant const int64_t& src_stride [[buffer(3)]], constant const int64_t& dst_stride [[buffer(4)]], uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); - auto dst_idx = elem_to_loc_1(index, dst_stride); + auto src_idx = elem_to_loc_1(index, src_stride); + auto dst_idx = elem_to_loc_1(index, dst_stride); dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_gg_nd2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], uint2 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_2(index, src_strides); - auto dst_idx = elem_to_loc_2(index, dst_strides); + auto src_idx = elem_to_loc_2(index, src_strides); + auto dst_idx = elem_to_loc_2(index, dst_strides); dst[dst_idx] = static_cast(src[src_idx]); } -template +template 
[[kernel]] void copy_gg_nd3( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], uint3 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_3(index, src_strides); - auto dst_idx = elem_to_loc_3(index, dst_strides); + auto src_idx = elem_to_loc_3(index, src_strides); + auto dst_idx = elem_to_loc_3(index, dst_strides); dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_gg( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], @@ -143,7 +142,7 @@ template constant const int64_t* dst_strides [[buffer(4)]], constant const int& ndim [[buffer(5)]], uint3 index [[thread_position_in_grid]]) { - auto idx = elem_to_loc_2_nd( + auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, src_shape, src_strides, @@ -153,8 +152,8 @@ template dst[idx.y] = static_cast(src[idx.x]); return; } - auto src_xstride = src_strides[ndim - 1]; - auto dst_xstride = dst_strides[ndim - 1]; + IdxT src_xstride = src_strides[ndim - 1]; + IdxT dst_xstride = dst_strides[ndim - 1]; auto xshape = src_shape[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { dst[idx.y] = static_cast(src[idx.x]); diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal index ffbf2be7c..5c444b30d 100644 --- a/mlx/backend/metal/kernels/copy.metal +++ b/mlx/backend/metal/kernels/copy.metal @@ -4,19 +4,25 @@ #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/copy.h" -#define instantiate_copy_all(tname, itype, otype) \ - instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \ - instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \ - instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ - instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ - instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \ - instantiate_kernel("g2_copy" #tname, 
copy_g_nd2, itype, otype) \ - instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \ - instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \ - instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \ - instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \ - instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \ - instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4) +#define instantiate_copy_all(tname, itype, otype) \ + instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \ + instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \ + instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ + instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ + instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \ + instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \ + instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \ + instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \ + instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \ + instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \ + instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \ + instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \ + instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \ + instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \ + instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \ + instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \ + instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \ + instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4) #define instantiate_copy_itype(itname, itype) \ instantiate_copy_all(itname ##bool_, itype, bool) \ diff --git a/mlx/backend/metal/kernels/gather.h b/mlx/backend/metal/kernels/gather.h index 4d3997ad8..b38ab6283 100644 --- 
a/mlx/backend/metal/kernels/gather.h +++ b/mlx/backend/metal/kernels/gather.h @@ -4,7 +4,7 @@ #include "mlx/backend/metal/kernels/indexing.h" -template +template METAL_FUNC void gather_impl( const device T* src [[buffer(0)]], device T* out [[buffer(1)]], @@ -16,18 +16,18 @@ METAL_FUNC void gather_impl( const thread Indices& indices, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - size_t src_idx = 0; + LocT src_idx = 0; for (int i = 0; i < NIDX; ++i) { - size_t idx_loc; + LocT idx_loc; if (IDX_NDIM == 0) { idx_loc = 0; } else if (IDX_NDIM == 1) { - idx_loc = index.x * indices.strides[indices.ndim * i]; + idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); } else { - idx_loc = index.x * indices.strides[indices.ndim * i]; + idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); idx_loc += indices.row_contiguous[i] ? index.y - : elem_to_loc( + : elem_to_loc( index.y, &indices.shapes[indices.ndim * i + 1], &indices.strides[indices.ndim * i + 1], @@ -35,17 +35,17 @@ METAL_FUNC void gather_impl( } auto ax = axes[i]; auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); - src_idx += idx_val * src_strides[ax]; + src_idx += static_cast(idx_val) * static_cast(src_strides[ax]); } - auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); + auto src_offset = + elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); - size_t out_idx = index.z; + LocT out_idx = index.z; if (IDX_NDIM == 1) { - out_idx += static_cast(grid_dim.z) * index.x; + out_idx += static_cast(grid_dim.z) * index.x; } else if (IDX_NDIM >= 2) { - out_idx += - grid_dim.z * (index.x * static_cast(grid_dim.y) + index.y); + out_idx += grid_dim.z * (index.x * static_cast(grid_dim.y) + index.y); } out[out_idx] = src[src_offset + src_idx]; } diff --git a/mlx/backend/metal/kernels/indexing.h b/mlx/backend/metal/kernels/indexing.h index ca4158df6..05bef96b6 100644 --- a/mlx/backend/metal/kernels/indexing.h +++ 
b/mlx/backend/metal/kernels/indexing.h @@ -14,7 +14,7 @@ struct Indices { }; template -METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) { +METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) { if (is_unsigned_v) { return idx; } else { diff --git a/mlx/backend/metal/kernels/scatter.h b/mlx/backend/metal/kernels/scatter.h index 9a38e62b1..63b09df3d 100644 --- a/mlx/backend/metal/kernels/scatter.h +++ b/mlx/backend/metal/kernels/scatter.h @@ -10,7 +10,8 @@ template < typename Op, int NIDX, bool UPD_ROW_CONTIG, - int NWORK> + int NWORK, + typename LocT> METAL_FUNC void scatter_impl( const device T* updates, device mlx_atomic* out, @@ -28,29 +29,31 @@ METAL_FUNC void scatter_impl( Op op; auto ind_idx = gid.y * NWORK; - size_t out_offset = 0; + LocT out_offset = 0; if (upd_size > 1) { - out_offset = - elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim); + out_offset = elem_to_loc( + gid.x, upd_shape + indices.ndim, out_strides, out_ndim); } for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) { - size_t out_idx = out_offset; + LocT out_idx = out_offset; for (int i = 0; i < NIDX; ++i) { auto idx_loc = indices.row_contiguous[i] ? 
ind_idx - : elem_to_loc( + : elem_to_loc( ind_idx, &indices.shapes[indices.ndim * i], &indices.strides[indices.ndim * i], indices.ndim); auto ax = axes[i]; auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); - out_idx += idx_val * out_strides[ax]; + out_idx += + static_cast(idx_val) * static_cast(out_strides[ax]); } - auto upd_idx = ind_idx * upd_size + gid.x; + auto upd_idx = ind_idx * static_cast(upd_size) + gid.x; if constexpr (!UPD_ROW_CONTIG) { - upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim); + upd_idx = + elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim); } op.atomic_update(out, updates[upd_idx], out_idx); } diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index 2bd1242c9..e19ea23df 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -32,13 +32,13 @@ template constant const size_t& b_strides, constant const size_t& c_strides, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_strides); - auto b_idx = elem_to_loc_1(index, b_strides); - auto c_idx = elem_to_loc_1(index, c_strides); + auto a_idx = elem_to_loc_1(index, a_strides); + auto b_idx = elem_to_loc_1(index, b_strides); + auto c_idx = elem_to_loc_1(index, c_strides); d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]); } -template +template [[kernel]] void ternary_g_nd2( device const bool* a, device const T* b, @@ -49,14 +49,14 @@ template constant const size_t c_strides[2], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - auto c_idx = elem_to_loc_2(index, c_strides); - size_t out_idx = index.x + size_t(grid_dim.x) * index.y; + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + auto c_idx = elem_to_loc_2(index, c_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; d[out_idx] 
= Op()(a[a_idx], b[b_idx], c[c_idx]); } -template +template [[kernel]] void ternary_g_nd3( device const bool* a, device const T* b, @@ -67,15 +67,14 @@ template constant const size_t c_strides[3], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - auto c_idx = elem_to_loc_3(index, c_strides); - size_t out_idx = - index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + auto c_idx = elem_to_loc_3(index, c_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); } -template +template [[kernel]] void ternary_g( device const bool* a, device const T* b, @@ -88,7 +87,7 @@ template constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_3_nd( + auto idx = elem_to_loc_3_nd( {N * index.x, index.y, index.z}, shape, a_strides, @@ -96,11 +95,10 @@ template c_strides, ndim); auto xshape = shape[ndim - 1]; - size_t out_idx = - N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); - auto a_xstride = a_strides[ndim - 1]; - auto b_xstride = b_strides[ndim - 1]; - auto c_xstride = c_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; + IdxT c_xstride = c_strides[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]); idx.x += a_xstride; diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal index f12e0048f..a509dacce 100644 --- a/mlx/backend/metal/kernels/ternary.metal +++ b/mlx/backend/metal/kernels/ternary.metal @@ -8,13 +8,16 @@ #include
"mlx/backend/metal/kernels/ternary_ops.h" #include "mlx/backend/metal/kernels/ternary.h" -#define instantiate_ternary_all(op, tname, type) \ - instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ - instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ - instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \ - instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \ - instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \ - instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) +#define instantiate_ternary_all(op, tname, type) \ + instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ + instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ + instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, uint) \ + instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \ + instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \ + instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \ + instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \ + instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \ + instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) #define instantiate_ternary_types(op) \ instantiate_ternary_all(op, bool_, bool) \ diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index 402d936c7..acfe176ef 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -18,7 +18,12 @@ template out[offset] = Op()(in[offset]); } -template +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = size_t> [[kernel]] void unary_g( device const T* in, device U* out, @@ -27,12 +32,11 @@ template device const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = - elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim); + auto idx = elem_to_loc( + {N * index.x, index.y, index.z}, in_shape,
in_strides, ndim); auto xshape = in_shape[ndim - 1]; - auto xstride = in_strides[ndim - 1]; - size_t out_idx = - N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); + IdxT xstride = in_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { out[out_idx++] = Op()(in[idx]); idx += xstride; diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 463708ab6..b196f023f 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -5,11 +5,13 @@ #include "mlx/backend/metal/kernels/unary_ops.h" #include "mlx/backend/metal/kernels/unary.h" -#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \ - instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \ - instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \ - instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4) - +#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \ + instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \ + instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \ + instantiate_kernel( \ + "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \ + instantiate_kernel( \ + "gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4) #define instantiate_unary_all_same(op, tname, type) \ instantiate_unary_all(op, tname, tname, type, type) diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 151c6a64d..1e1b91c0b 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -89,44 +89,45 @@ struct Limits { /////////////////////////////////////////////////////////////////////////////// // Single Array with generic dims -template
-METAL_FUNC stride_t elem_to_loc( +template +METAL_FUNC IdxT elem_to_loc( uint elem, constant const int* shape, - constant const stride_t* strides, + constant const StrideT* strides, int ndim) { - stride_t loc = 0; + IdxT loc = 0; for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - loc += (elem % shape[i]) * strides[i]; + loc += (elem % shape[i]) * IdxT(strides[i]); elem /= shape[i]; } return loc; } -template -METAL_FUNC stride_t elem_to_loc( - stride_t elem, +template +METAL_FUNC IdxT elem_to_loc( + StrideT elem, constant const int* shape, - constant const stride_t* strides, + constant const StrideT* strides, int ndim) { - stride_t loc = 0; + IdxT loc = 0; for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - loc += (elem % shape[i]) * strides[i]; + loc += (elem % shape[i]) * IdxT(strides[i]); elem /= shape[i]; } return loc; } // Non templated version to handle arbitrary dims -template -METAL_FUNC stride_t elem_to_loc( +template +METAL_FUNC IdxT elem_to_loc( uint3 elem, constant const int* shape, - constant const stride_t* strides, + constant const StrideT* strides, int ndim) { - stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2]; + IdxT loc = + elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]); for (int d = ndim - 3; d >= 0; --d) { - loc += (elem.z % shape[d]) * strides[d]; + loc += (elem.z % shape[d]) * IdxT(strides[d]); elem.z /= shape[d]; } return loc; @@ -135,61 +136,65 @@ METAL_FUNC stride_t elem_to_loc( /////////////////////////////////////////////////////////////////////////////// // Single Array with fixed N dims -template -METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) { - return elem * stride; +template +METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) { + return elem * IdxT(stride); } -template -METAL_FUNC stride_t -elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) { - return elem.x * strides[1] + elem.y * strides[0]; +template +METAL_FUNC IdxT 
elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) { + return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]); } -template -METAL_FUNC stride_t -elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) { - return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0]; +template +METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) { + return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) + + elem.z * IdxT(strides[0]); } /////////////////////////////////////////////////////////////////////////////// // Multiple Arrays with generic dims -template -METAL_FUNC ulong2 elem_to_loc_2_nd( +template +METAL_FUNC vec elem_to_loc_2_nd( uint3 elem, constant const int* shape, - constant const stride_t* a_strides, - constant const stride_t* b_strides, + constant const StrideT* a_strides, + constant const StrideT* b_strides, int ndim) { - ulong2 loc = { - ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), - ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])}; + vec loc = { + IdxT( + elem.x * IdxT(a_strides[ndim - 1]) + + IdxT(elem.y) * IdxT(a_strides[ndim - 2])), + IdxT( + elem.x * IdxT(b_strides[ndim - 1]) + + elem.y * IdxT(b_strides[ndim - 2]))}; for (int d = ndim - 3; d >= 0; --d) { uint l = elem.z % shape[d]; - loc.x += l * a_strides[d]; - loc.y += l * b_strides[d]; + loc.x += l * IdxT(a_strides[d]); + loc.y += l * IdxT(b_strides[d]); elem.z /= shape[d]; } return loc; } -METAL_FUNC ulong3 elem_to_loc_3_nd( +template +METAL_FUNC vec elem_to_loc_3_nd( uint3 elem, constant const int* shape, constant const size_t* a_strides, constant const size_t* b_strides, constant const size_t* c_strides, int ndim) { - ulong3 loc = { - elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2], - elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2], - elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]}; + vec loc = { + elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 
2]), + elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]), + elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])}; for (int d = ndim - 3; d >= 0; --d) { uint l = elem.z % shape[d]; - loc.x += l * a_strides[d]; - loc.y += l * b_strides[d]; - loc.z += l * c_strides[d]; + loc.x += l * IdxT(a_strides[d]); + loc.y += l * IdxT(b_strides[d]); + loc.z += l * IdxT(c_strides[d]); elem.z /= shape[d]; } return loc; diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 0d0fdc657..50c4c8cce 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -36,27 +36,31 @@ void ternary_op_gpu_inplace( }; auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse(); - bool use_2d = out.data_size() > UINT_MAX; + bool large = out.data_size() > UINT_MAX; auto ndim = shape.size(); - int work_per_thread = (topt == TernaryOpType::General) ? 4 : 1; - std::string kernel_name; - { - std::ostringstream kname; - if (topt == TernaryOpType::General) { - kname << "g"; - if (shape.size() <= 3) { - kname << shape.size(); - } else if (work_per_thread > 1) { - kname << "n" << work_per_thread; - } - } else if (use_2d) { - kname << "v2"; - } else { - kname << "v"; - } - kname << "_" << op << type_to_name(b); - kernel_name = kname.str(); + int work_per_thread; + if (topt == TernaryOpType::General) { + work_per_thread = large ? 
4 : 2; + } else { + work_per_thread = 1; } + std::string kernel_name; + if (topt == TernaryOpType::General) { + kernel_name = "g"; + if (shape.size() <= 3) { + kernel_name += std::to_string(shape.size()); + } else if (work_per_thread > 1) { + concatenate(kernel_name, "n", std::to_string(work_per_thread)); + } + if (large) { + kernel_name += "large"; + } + } else if (large) { + kernel_name = "v2"; + } else { + kernel_name = "v"; + } + concatenate(kernel_name, "_", op, type_to_name(b)); auto& d = metal::device(s.device); @@ -107,8 +111,8 @@ void ternary_op_gpu_inplace( thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) + : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index e9baad065..c1004df22 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -35,16 +35,19 @@ void unary_op_gpu_inplace( }; auto [shape, strides] = maybe_collapse(); int ndim = shape.size(); - int work_per_thread = !contig ? 4 : 1; size_t nthreads = contig ? in.data_size() : in.size(); - bool use_2d = nthreads > UINT32_MAX; + bool large = nthreads > UINT32_MAX; + int work_per_thread = !contig && large ? 4 : 1; std::string kernel_name; if (contig) { - kernel_name = (use_2d ? "v2" : "v"); + kernel_name = (large ? "v2" : "v"); } else { - kernel_name = (work_per_thread == 4 ? 
"gn4" : "g"); + kernel_name = "gn" + std::to_string(work_per_thread); + if (large) { + kernel_name += "_large"; + } } - kernel_name += "_" + op + type_to_name(in) + type_to_name(out); + concatenate(kernel_name, "_", op, type_to_name(in), type_to_name(out)); auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); @@ -73,8 +76,8 @@ void unary_op_gpu_inplace( thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) + : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index f2a9c7b20..366da6287 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -61,4 +61,15 @@ inline void debug_set_primitive_buffer_label( std::string get_primitive_string(Primitive* primitive); +template +void concatenate(std::string& acc, T first) { + acc += first; +} + +template +void concatenate(std::string& acc, T first, Args... args) { + acc += first; + concatenate(acc, args...); +} + } // namespace mlx::core