From c6a20b427ac624f4b9ac9d6f8e4c5847f1dbb672 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 6 Jun 2025 11:37:40 -0700 Subject: [PATCH] Improve metal elementwise kernels (#2247) * improve metal elementwise kernels * compile and copy * fix jit --- mlx/backend/metal/binary.cpp | 15 ++- mlx/backend/metal/compiled.cpp | 26 +++++- mlx/backend/metal/copy.cpp | 15 +-- mlx/backend/metal/jit_kernels.cpp | 52 ++++++++--- mlx/backend/metal/kernels/binary.h | 60 +++++++++--- mlx/backend/metal/kernels/binary.metal | 47 ++++++---- mlx/backend/metal/kernels/binary_two.h | 102 +++++++++++++++------ mlx/backend/metal/kernels/binary_two.metal | 39 +++++--- mlx/backend/metal/kernels/copy.h | 52 ++++++++--- mlx/backend/metal/kernels/copy.metal | 20 ++-- mlx/backend/metal/kernels/ternary.h | 22 ++++- mlx/backend/metal/kernels/ternary.metal | 14 ++- mlx/backend/metal/kernels/unary.h | 22 ++++- mlx/backend/metal/kernels/unary.metal | 88 ++++++++++-------- mlx/backend/metal/ternary.cpp | 4 +- mlx/backend/metal/unary.cpp | 4 +- mlx/backend/metal/utils.h | 4 + 17 files changed, 412 insertions(+), 174 deletions(-) diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index c3c67e4d5..54aaf153c 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -31,13 +31,13 @@ std::string get_kernel_name( kname = "ss"; break; case BinaryOpType::ScalarVector: - kname = (large ? "sv2" : "sv"); + kname = "sv"; break; case BinaryOpType::VectorScalar: - kname = (large ? "vs2" : "vs"); + kname = "vs"; break; case BinaryOpType::VectorVector: - kname = (large ? "vv2" : "vv"); + kname = "vv"; break; case BinaryOpType::General: kname = "g"; @@ -51,6 +51,13 @@ std::string get_kernel_name( } break; } + if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) { + if (large) { + kname += "2"; + } else if (work_per_thread > 1) { + kname += "n"; + } + } concatenate(kname, "_", op, type_to_name(a)); return kname; } @@ -90,7 +97,7 @@ void binary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > UINT32_MAX; - work_per_thread = get_work_per_thread(a.dtype()); + work_per_thread = get_work_per_thread(a.dtype(), out.data_size()); } std::string kernel_name = get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread); diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 6a67b4f57..88edc6baa 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -278,7 +278,21 @@ void Compiled::eval_gpu( /* ndim = */ 0, /* dynamic_dims = */ false, /* use_big_index = */ false, - /* work_per_thread = */ work_per_thread); + /* work_per_thread = */ 1); + if (work_per_thread > 1) { + build_kernel( + kernel, + kernel_lib_ + "_contiguous_n", + inputs_, + outputs_, + tape_, + is_constant_, + /* contiguous = */ true, + /* ndim = */ 0, + /* dynamic_dims = */ false, + /* use_big_index = */ false, + /* work_per_thread = */ work_per_thread); + } build_kernel( kernel, kernel_lib_ + "_contiguous_large", @@ -358,12 +372,20 @@ void Compiled::eval_gpu( int ndim = shape.size(); bool dynamic = ndim >= 8; auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); + int work_per_thread = 1; if (!contiguous) { if (dynamic) { kernel_name += "dynamic"; } else { kernel_name += std::to_string(shape.size()); } + work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1; + } else { + work_per_thread = + get_work_per_thread(outputs[0].dtype(), outputs[0].data_size()); + if (work_per_thread > 1 && !large) { + kernel_name += "_n"; + } } if (large) { kernel_name += "_large"; @@ -420,7 +442,6 @@ void Compiled::eval_gpu( // Launch the kernel if (contiguous) { - int work_per_thread = get_work_per_thread(outputs[0].dtype()); size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread); MTL::Size group_dims( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); @@ -433,7 +454,6 @@ void Compiled::eval_gpu( size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t rest = outputs[0].size() / (dim0 * dim1); - int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1; dim0 = (dim0 + work_per_thread - 1) / work_per_thread; NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); int pow2; diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 8dfe15c11..8123b793e 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -55,10 +55,10 @@ void copy_gpu_inplace( std::string kernel_name; switch (ctype) { case CopyType::Scalar: - kernel_name = (large ? "s2" : "s"); + kernel_name = large ? "s2" : "s"; break; case CopyType::Vector: - kernel_name = (large ? "v2" : "v"); + kernel_name = large ? "v2" : "v"; break; case CopyType::General: kernel_name = "g"; @@ -85,7 +85,10 @@ void copy_gpu_inplace( } } } else { - work_per_thread = get_work_per_thread(in.dtype()); + work_per_thread = get_work_per_thread(out.dtype(), out.data_size()); + if (work_per_thread > 1) { + kernel_name += "n"; + } } concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out) @@ -170,9 +173,10 @@ void fill_gpu(const array& val, array& out, const Stream& s) { } out.set_data(allocator::malloc(out.nbytes())); bool large = out.data_size() > UINT32_MAX; + int work_per_thread = get_work_per_thread(out.dtype(), out.data_size()); auto& d = metal::device(s.device); - std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" + - type_to_name(val) + type_to_name(out); + std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s"); + concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out)); auto kernel = get_copy_kernel(d, kernel_name, val, out); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -180,7 +184,6 @@ void fill_gpu(const array& val, array& out, const Stream& s) { compute_encoder.set_input_array(val, 0); compute_encoder.set_output_array(out, 1); - int work_per_thread = get_work_per_thread(val.dtype()); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 5206c9b54..15e21af6c 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -41,7 +41,11 @@ MTL::ComputePipelineState* get_unary_kernel( std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::unary_ops(), metal::unary()); kernel_source += - get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op); + get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1); + if (get_work_per_thread(in_type) > 1) { + kernel_source += + get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op); + } kernel_source += get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op); kernel_source += get_template_definition( @@ -59,11 +63,8 @@ void append_binary_kernels( Dtype out_type, const std::string op, std::string& kernel_source) { - const std::array, 10> kernel_types = {{ + const std::array, 7> kernel_types = {{ {"ss", "binary_ss"}, - {"vs", "binary_vs"}, - {"sv", "binary_sv"}, - {"vv", "binary_vv"}, {"vs2", "binary_vs2"}, {"sv2", "binary_sv2"}, {"vv2", "binary_vv2"}, @@ -78,6 +79,22 @@ void append_binary_kernels( kernel_source += get_template_definition(name + "_" + lib_name, func, in_t, out_t, op); } + kernel_source += get_template_definition( + "vs_" + lib_name, "binary_vs", in_t, out_t, op, 1); + kernel_source += get_template_definition( + "sv_" + lib_name, "binary_sv", in_t, out_t, op, 1); + kernel_source += get_template_definition( + "vv_" + lib_name, "binary_vv", in_t, out_t, op, 1); + + if (get_work_per_thread(in_type) > 1) { + kernel_source += get_template_definition( + "vsn_" + lib_name, "binary_vs", in_t, out_t, op); + kernel_source += get_template_definition( + "svn_" + lib_name, "binary_sv", in_t, out_t, op); + kernel_source += get_template_definition( + "vvn_" + lib_name, "binary_vv", in_t, out_t, op); + } + kernel_source += get_template_definition( "g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int"); kernel_source += get_template_definition( @@ -133,8 +150,7 @@ MTL::ComputePipelineState* get_ternary_kernel( auto t_str = get_type_string(type); std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::ternary_ops(), metal::ternary()); - const std::array, 5> kernel_types = {{ - {"v", "ternary_v"}, + const std::array, 4> kernel_types = {{ {"v2", "ternary_v2"}, {"g1large", "ternary_g_nd1"}, {"g2large", "ternary_g_nd2"}, @@ -144,6 +160,13 @@ MTL::ComputePipelineState* get_ternary_kernel( kernel_source += get_template_definition(name + "_" + lib_name, func, t_str, op); } + if (get_work_per_thread(type) > 1) { + kernel_source += + get_template_definition("vn_" + lib_name, "ternary_v", t_str, op); + } + + kernel_source += + get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1); kernel_source += get_template_definition( "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int"); kernel_source += get_template_definition( @@ -170,15 +193,22 @@ MTL::ComputePipelineState* get_copy_kernel( kernel_source += metal::copy(); auto in_type = get_type_string(in.dtype()); auto out_type = get_type_string(out.dtype()); - kernel_source += - get_template_definition("s_" + lib_name, "copy_s", in_type, out_type); + kernel_source += get_template_definition( + "s_" + lib_name, "copy_s", in_type, out_type, 1); kernel_source += get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type); - kernel_source += - get_template_definition("v_" + lib_name, "copy_v", in_type, out_type); + kernel_source += get_template_definition( + "v_" + lib_name, "copy_v", in_type, out_type, 1); kernel_source += get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type); + if (get_work_per_thread(out.dtype()) > 1) { + kernel_source += get_template_definition( + "sn_" + lib_name, "copy_s", in_type, out_type); + kernel_source += get_template_definition( + "vn_" + lib_name, "copy_v", in_type, out_type); + } + kernel_source += get_template_definition( "g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int"); kernel_source += get_template_definition( diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index ffc33ad82..f1df88535 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -17,8 +17,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - c[index + i] = Op()(a[0], b[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } } } @@ -30,8 +36,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - c[index + i] = Op()(a[index + i], b[0]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } } } @@ -43,8 +55,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - c[index + i] = Op()(a[index + i], b[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } } } @@ -57,8 +75,14 @@ template ::n> uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - c[offset + i] = Op()(a[0], b[offset + i]); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } } } @@ -71,8 +95,14 @@ template ::n> uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - c[offset + i] = Op()(a[offset + i], b[0]); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } } } @@ -85,8 +115,14 @@ template ::n> uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - c[offset + i] = Op()(a[offset + i], b[offset + i]); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } } } diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 1d555fefa..17ed13c57 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -9,11 +9,16 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary.h" -#define instantiate_binary_all(op, tname, itype, otype) \ +#define instantiate_binary_work_per_thread(op, tname, itype, otype) \ + instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) \ + +#define instantiate_binary_base(op, tname, itype, otype) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \ + instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ @@ -26,15 +31,19 @@ instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \ instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) -#define instantiate_binary_integer(op) \ - instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ - instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ - instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ - instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ - instantiate_binary_all(op, int8, int8_t, int8_t) \ - instantiate_binary_all(op, int16, int16_t, int16_t) \ - instantiate_binary_all(op, int32, int32_t, int32_t) \ - instantiate_binary_all(op, int64, int64_t, int64_t) +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_binary_base(op, tname, itype, otype) \ + instantiate_binary_work_per_thread(op, tname, itype, otype) + +#define instantiate_binary_integer(op) \ + instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ + instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ + instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ + instantiate_binary_base(op, uint64, uint64_t, uint64_t) \ + instantiate_binary_all(op, int8, int8_t, int8_t) \ + instantiate_binary_all(op, int16, int16_t, int16_t) \ + instantiate_binary_all(op, int32, int32_t, int32_t) \ + instantiate_binary_base(op, int64, int64_t, int64_t) #define instantiate_binary_float(op) \ instantiate_binary_all(op, float16, half, half) \ @@ -44,7 +53,7 @@ #define instantiate_binary_types(op) \ instantiate_binary_all(op, bool_, bool, bool) \ instantiate_binary_integer(op) \ - instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ + instantiate_binary_base(op, complex64, complex64_t, complex64_t)\ instantiate_binary_float(op) #define instantiate_binary_types_bool(op) \ @@ -52,15 +61,15 @@ instantiate_binary_all(op, uint8, uint8_t, bool) \ instantiate_binary_all(op, uint16, uint16_t, bool) \ instantiate_binary_all(op, uint32, uint32_t, bool) \ - instantiate_binary_all(op, uint64, uint64_t, bool) \ + instantiate_binary_base(op, uint64, uint64_t, bool) \ instantiate_binary_all(op, int8, int8_t, bool) \ instantiate_binary_all(op, int16, int16_t, bool) \ instantiate_binary_all(op, int32, int32_t, bool) \ - instantiate_binary_all(op, int64, int64_t, bool) \ + instantiate_binary_base(op, int64, int64_t, bool) \ instantiate_binary_all(op, float16, half, bool) \ instantiate_binary_all(op, float32, float, bool) \ instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \ - instantiate_binary_all(op, complex64, complex64_t, bool) + instantiate_binary_base(op, complex64, complex64_t, bool) instantiate_binary_types(Add) instantiate_binary_types(Divide) @@ -71,7 +80,7 @@ instantiate_binary_types_bool(Less) instantiate_binary_types_bool(LessEqual) instantiate_binary_types_bool(NotEqual) instantiate_binary_float(LogAddExp) -instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t) +instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t) instantiate_binary_types(Maximum) instantiate_binary_types(Minimum) instantiate_binary_types(Multiply) @@ -84,7 +93,7 @@ instantiate_binary_float(ArcTan2) instantiate_binary_all(NaNEqual, float16, half, bool) instantiate_binary_all(NaNEqual, float32, float, bool) instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool) -instantiate_binary_all(NaNEqual, complex64, complex64_t, bool) +instantiate_binary_base(NaNEqual, complex64, complex64_t, bool) instantiate_binary_all(LogicalOr, bool_, bool, bool) instantiate_binary_all(LogicalAnd, bool_, bool, bool) diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h index e261d33c4..4455e4ca9 100644 --- a/mlx/backend/metal/kernels/binary_two.h +++ b/mlx/backend/metal/kernels/binary_two.h @@ -21,10 +21,18 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - auto out = Op()(a[0], b[index + i]); - c[index + i] = out[0]; - d[index + i] = out[1]; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } } @@ -37,10 +45,18 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - auto out = Op()(a[index + i], b[0]); - c[index + i] = out[0]; - d[index + i] = out[1]; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } } @@ -53,10 +69,18 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - auto out = Op()(a[index + i], b[index + i]); - c[index + i] = out[0]; - d[index + i] = out[1]; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } } @@ -69,11 +93,19 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - auto out = Op()(a[0], b[offset + i]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } } @@ -86,11 +118,19 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - auto out = Op()(a[offset + i], b[0]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } } @@ -103,11 +143,19 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - auto out = Op()(a[offset + i], b[offset + i]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } } diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index 984a28320..c7d3ecdf0 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -7,11 +7,16 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_two.h" -#define instantiate_binary_all(op, tname, itype, otype) \ +#define instantiate_binary_work_per_thread(op, tname, itype, otype) \ + instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) + +#define instantiate_binary_base(op, tname, itype, otype) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \ + instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ @@ -24,22 +29,26 @@ instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_binary_base(op, tname, itype, otype) \ + instantiate_binary_work_per_thread(op, tname, itype, otype) + #define instantiate_binary_float(op) \ instantiate_binary_all(op, float16, half, half) \ instantiate_binary_all(op, float32, float, float) \ instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t) -#define instantiate_binary_types(op) \ - instantiate_binary_all(op, bool_, bool, bool) \ - instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ - instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ - instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ - instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ - instantiate_binary_all(op, int8, int8_t, int8_t) \ - instantiate_binary_all(op, int16, int16_t, int16_t) \ - instantiate_binary_all(op, int32, int32_t, int32_t) \ - instantiate_binary_all(op, int64, int64_t, int64_t) \ - instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ +#define instantiate_binary_types(op) \ + instantiate_binary_all(op, bool_, bool, bool) \ + instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ + instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ + instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ + instantiate_binary_base(op, uint64, uint64_t, uint64_t) \ + instantiate_binary_all(op, int8, int8_t, int8_t) \ + instantiate_binary_all(op, int16, int16_t, int16_t) \ + instantiate_binary_all(op, int32, int32_t, int32_t) \ + instantiate_binary_base(op, int64, int64_t, int64_t) \ + instantiate_binary_base(op, complex64, complex64_t, complex64_t) \ instantiate_binary_float(op) instantiate_binary_types(DivMod) // clang-format on diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h index 2469d1f3d..cf22347ee 100644 --- a/mlx/backend/metal/kernels/copy.h +++ b/mlx/backend/metal/kernels/copy.h @@ -1,52 +1,76 @@ // Copyright © 2024 Apple Inc. -template ::n> +template ::n> [[kernel]] void copy_s( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - dst[index + i] = static_cast(src[0]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[0]); + } } } -template ::n> +template ::n> [[kernel]] void copy_v( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - dst[index + i] = static_cast(src[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[index + i]); + } } } -template ::n> +template ::n> [[kernel]] void copy_s2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - dst[offset + i] = static_cast(src[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[0]); + } } } -template ::n> +template ::n> [[kernel]] void copy_v2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - dst[offset + i] = static_cast(src[offset + i]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } } } diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal index bbf268158..fcf8884f8 100644 --- a/mlx/backend/metal/kernels/copy.metal +++ b/mlx/backend/metal/kernels/copy.metal @@ -4,9 +4,13 @@ #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/copy.h" -#define instantiate_copy_all(tname, itype, otype) \ - instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \ - instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \ +#define instantiate_copy_work_per_thread(tname, itype, otype) \ + instantiate_kernel("sn_copy" #tname, copy_s, itype, otype) \ + instantiate_kernel("vn_copy" #tname, copy_v, itype, otype) + +#define instantiate_copy_base(tname, itype, otype) \ + instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1) \ + instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1) \ instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \ @@ -18,6 +22,10 @@ instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \ instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) +#define instantiate_copy_all(tname, itype, otype) \ + instantiate_copy_base(tname, itype, otype) \ + instantiate_copy_work_per_thread(tname, itype, otype) + #define instantiate_copy_same(tname, type) \ instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \ @@ -42,15 +50,15 @@ instantiate_copy_all(itname ##uint8, itype, uint8_t) \ instantiate_copy_all(itname ##uint16, itype, uint16_t) \ instantiate_copy_all(itname ##uint32, itype, uint32_t) \ - instantiate_copy_all(itname ##uint64, itype, uint64_t) \ + instantiate_copy_base(itname ##uint64, itype, uint64_t) \ instantiate_copy_all(itname ##int8, itype, int8_t) \ instantiate_copy_all(itname ##int16, itype, int16_t) \ instantiate_copy_all(itname ##int32, itype, int32_t) \ - instantiate_copy_all(itname ##int64, itype, int64_t) \ + instantiate_copy_base(itname ##int64, itype, int64_t) \ instantiate_copy_all(itname ##float16, itype, half) \ instantiate_copy_all(itname ##float32, itype, float) \ instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \ - instantiate_copy_all(itname ##complex64, itype, complex64_t) + instantiate_copy_base(itname ##complex64, itype, complex64_t) instantiate_copy_itype(bool_, bool) instantiate_copy_itype(uint8, uint8_t) diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index 5251dc7e9..570f5e4d6 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -9,8 +9,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } } } @@ -23,9 +29,15 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } } } diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal index cceb53061..6da258b6f 100644 --- a/mlx/backend/metal/kernels/ternary.metal +++ b/mlx/backend/metal/kernels/ternary.metal @@ -8,8 +8,8 @@ #include "mlx/backend/metal/kernels/ternary_ops.h" #include "mlx/backend/metal/kernels/ternary.h" -#define instantiate_ternary_all(op, tname, type) \ - instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ +#define instantiate_ternary_base(op, tname, type) \ + instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \ instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \ instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \ @@ -20,19 +20,23 @@ instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \ instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \ +#define instantiate_ternary_all(op, tname, type) \ + instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \ + instantiate_ternary_base(op, tname, type) + #define instantiate_ternary_types(op) \ instantiate_ternary_all(op, bool_, bool) \ instantiate_ternary_all(op, uint8, uint8_t) \ instantiate_ternary_all(op, uint16, uint16_t) \ instantiate_ternary_all(op, uint32, uint32_t) \ - instantiate_ternary_all(op, uint64, uint64_t) \ + instantiate_ternary_base(op, uint64, uint64_t) \ instantiate_ternary_all(op, int8, int8_t) \ instantiate_ternary_all(op, int16, int16_t) \ instantiate_ternary_all(op, int32, int32_t) \ - instantiate_ternary_all(op, int64, int64_t) \ + instantiate_ternary_base(op, int64, int64_t) \ instantiate_ternary_all(op, float16, half) \ instantiate_ternary_all(op, float32, float) \ instantiate_ternary_all(op, bfloat16, bfloat16_t) \ - instantiate_ternary_all(op, complex64, complex64_t) // clang-format on + instantiate_ternary_base(op, complex64, complex64_t) // clang-format on instantiate_ternary_types(Select) diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index b5eaab2e9..649ba7f2c 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -7,8 +7,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - out[index + i] = Op()(in[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + out[index + i] = Op()(in[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[index + i] = Op()(in[index + i]); + } } } @@ -19,9 +25,15 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - out[offset + i] = Op()(in[offset + i]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + out[offset + i] = Op()(in[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[offset + i] = Op()(in[offset + i]); + } } } diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index afced7eb7..160ef4af1 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -5,31 +5,41 @@ #include "mlx/backend/metal/kernels/unary_ops.h" #include "mlx/backend/metal/kernels/unary.h" -#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \ - instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \ - instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \ - instantiate_kernel( \ - "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \ - instantiate_kernel( \ +#define instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) \ + instantiate_kernel("vn_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) + +#define instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \ + instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op, 1) \ + instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \ + instantiate_kernel( \ + "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \ + instantiate_kernel( \ "gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4) +#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \ + instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \ + instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) + #define instantiate_unary_all_same(op, tname, type) \ instantiate_unary_all(op, tname, tname, type, type) +#define instantiate_unary_base_same(op, tname, type) \ + instantiate_unary_base(op, tname, tname, type, type) + #define instantiate_unary_float(op) \ instantiate_unary_all_same(op, float16, half) \ instantiate_unary_all_same(op, float32, float) \ instantiate_unary_all_same(op, bfloat16, bfloat16_t) -#define instantiate_unary_int(op) \ - instantiate_unary_all_same(op, uint8, uint8_t) \ - instantiate_unary_all_same(op, uint16, uint16_t) \ - instantiate_unary_all_same(op, uint32, uint32_t) \ - instantiate_unary_all_same(op, uint64, uint64_t) \ - instantiate_unary_all_same(op, int8, int8_t) \ - instantiate_unary_all_same(op, int16, int16_t) \ - instantiate_unary_all_same(op, int32, int32_t) \ - instantiate_unary_all_same(op, int64, int64_t) +#define instantiate_unary_int(op) \ + instantiate_unary_all_same(op, uint8, uint8_t) \ + instantiate_unary_all_same(op, uint16, uint16_t) \ + instantiate_unary_all_same(op, uint32, uint32_t) \ + instantiate_unary_base_same(op, uint64, uint64_t) \ + instantiate_unary_all_same(op, int8, int8_t) \ + instantiate_unary_all_same(op, int16, int16_t) \ + instantiate_unary_all_same(op, int32, int32_t) \ + instantiate_unary_base_same(op, int64, int64_t) #define instantiate_unary_types(op) \ instantiate_unary_all_same(op, bool_, bool) \ @@ -68,29 +78,29 @@ instantiate_unary_float(Tanh) instantiate_unary_float(Round) instantiate_unary_int(BitwiseInvert) -instantiate_unary_all_same(Abs, complex64, complex64_t) -instantiate_unary_all_same(ArcCos, complex64, complex64_t) -instantiate_unary_all_same(ArcSin, complex64, complex64_t) -instantiate_unary_all_same(ArcTan, complex64, complex64_t) -instantiate_unary_all_same(Conjugate, complex64, complex64_t) -instantiate_unary_all_same(Cos, complex64, complex64_t) -instantiate_unary_all_same(Cosh, complex64, complex64_t) -instantiate_unary_all_same(Exp, complex64, complex64_t) -instantiate_unary_all_same(Log, complex64, complex64_t) -instantiate_unary_all_same(Log1p, complex64, complex64_t) -instantiate_unary_all_same(Log2, complex64, complex64_t) -instantiate_unary_all_same(Log10, complex64, complex64_t) -instantiate_unary_all_same(Negative, complex64, complex64_t) -instantiate_unary_all_same(Sign, complex64, complex64_t) -instantiate_unary_all_same(Sin, complex64, complex64_t) -instantiate_unary_all_same(Sinh, complex64, complex64_t) -instantiate_unary_all_same(Square, complex64, complex64_t) -instantiate_unary_all_same(Sqrt, complex64, complex64_t) -instantiate_unary_all_same(Rsqrt, complex64, complex64_t) -instantiate_unary_all_same(Tan, complex64, complex64_t) -instantiate_unary_all_same(Tanh, complex64, complex64_t) -instantiate_unary_all_same(Round, complex64, complex64_t) -instantiate_unary_all(Real, complex64, float32, complex64_t, float) -instantiate_unary_all(Imag, complex64, float32, complex64_t, float) +instantiate_unary_base_same(Abs, complex64, complex64_t) +instantiate_unary_base_same(ArcCos, complex64, complex64_t) +instantiate_unary_base_same(ArcSin, complex64, complex64_t) +instantiate_unary_base_same(ArcTan, complex64, complex64_t) +instantiate_unary_base_same(Conjugate, complex64, complex64_t) +instantiate_unary_base_same(Cos, complex64, complex64_t) +instantiate_unary_base_same(Cosh, complex64, complex64_t) +instantiate_unary_base_same(Exp, complex64, complex64_t) +instantiate_unary_base_same(Log, complex64, complex64_t) +instantiate_unary_base_same(Log1p, complex64, complex64_t) +instantiate_unary_base_same(Log2, complex64, complex64_t) +instantiate_unary_base_same(Log10, complex64, complex64_t) +instantiate_unary_base_same(Negative, complex64, complex64_t) +instantiate_unary_base_same(Sign, complex64, complex64_t) +instantiate_unary_base_same(Sin, complex64, complex64_t) +instantiate_unary_base_same(Sinh, complex64, complex64_t) +instantiate_unary_base_same(Square, complex64, complex64_t) +instantiate_unary_base_same(Sqrt, complex64, complex64_t) +instantiate_unary_base_same(Rsqrt, complex64, complex64_t) +instantiate_unary_base_same(Tan, complex64, complex64_t) +instantiate_unary_base_same(Tanh, complex64, complex64_t) +instantiate_unary_base_same(Round, complex64, complex64_t) +instantiate_unary_base(Real, complex64, float32, complex64_t, float) +instantiate_unary_base(Imag, complex64, float32, complex64_t, float) instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 0b821151e..22f2a1985 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -45,7 +45,7 @@ void ternary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > INT32_MAX; - work_per_thread = get_work_per_thread(b.dtype()); + work_per_thread = get_work_per_thread(b.dtype(), out.data_size()); } std::string kernel_name; if (topt == TernaryOpType::General) { @@ -60,6 +60,8 @@ void ternary_op_gpu_inplace( } } else if (large) { kernel_name = "v2"; + } else if (work_per_thread > 1) { + kernel_name = "vn"; } else { kernel_name = "v"; } diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index 368e693a9..850c17376 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -43,8 +43,8 @@ void unary_op_gpu_inplace( int work_per_thread; std::string kernel_name; if (contig) { - work_per_thread = get_work_per_thread(in.dtype()); - kernel_name = (large ? "v2" : "v"); + work_per_thread = get_work_per_thread(in.dtype(), in.data_size()); + kernel_name = (large ? "v2" : (work_per_thread > 1 ? "vn" : "v")); } else { work_per_thread = large ? 4 : 1; kernel_name = "gn" + std::to_string(work_per_thread); diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 576fb9107..a491521a0 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -72,6 +72,10 @@ void concatenate(std::string& acc, T first, Args... args) { inline int get_work_per_thread(Dtype dtype) { return std::max(1, 8 / dtype.size()); } +inline int get_work_per_thread(Dtype dtype, size_t size) { + constexpr size_t wpt_threshold = 1 << 16; + return size < wpt_threshold ? 1 : std::max(1, 8 / dtype.size()); +} inline size_t ceildiv(size_t n, size_t m) { return (n + m - 1) / m;