From d0ebd18d7d594b11a297e4e48a5396574b770892 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 6 Jun 2025 07:01:03 -0700 Subject: [PATCH] improve metal elementwise kernels --- mlx/backend/metal/binary.cpp | 15 ++- mlx/backend/metal/jit_kernels.cpp | 20 +++- mlx/backend/metal/kernels/binary.h | 60 +++++++++--- mlx/backend/metal/kernels/binary.metal | 47 ++++++---- mlx/backend/metal/kernels/binary_two.h | 102 +++++++++++++++------ mlx/backend/metal/kernels/binary_two.metal | 39 +++++--- mlx/backend/metal/kernels/ternary.h | 22 ++++- mlx/backend/metal/kernels/ternary.metal | 14 ++- mlx/backend/metal/kernels/unary.h | 22 ++++- mlx/backend/metal/kernels/unary.metal | 88 ++++++++++-------- mlx/backend/metal/ternary.cpp | 4 +- mlx/backend/metal/unary.cpp | 4 +- mlx/backend/metal/utils.h | 4 + 13 files changed, 302 insertions(+), 139 deletions(-) diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index c3c67e4d5..54aaf153c 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -31,13 +31,13 @@ std::string get_kernel_name( kname = "ss"; break; case BinaryOpType::ScalarVector: - kname = (large ? "sv2" : "sv"); + kname = "sv"; break; case BinaryOpType::VectorScalar: - kname = (large ? "vs2" : "vs"); + kname = "vs"; break; case BinaryOpType::VectorVector: - kname = (large ? "vv2" : "vv"); + kname = "vv"; break; case BinaryOpType::General: kname = "g"; @@ -51,6 +51,13 @@ std::string get_kernel_name( } break; } + if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) { + if (large) { + kname += "2"; + } else if (work_per_thread > 1) { + kname += "n"; + } + } concatenate(kname, "_", op, type_to_name(a)); return kname; } @@ -90,7 +97,7 @@ void binary_op_gpu_inplace( work_per_thread = large ? 
4 : 2; } else { large = out.data_size() > UINT32_MAX; - work_per_thread = get_work_per_thread(a.dtype()); + work_per_thread = get_work_per_thread(a.dtype(), out.data_size()); } std::string kernel_name = get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 5206c9b54..ec379717f 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -41,7 +41,9 @@ MTL::ComputePipelineState* get_unary_kernel( std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::unary_ops(), metal::unary()); kernel_source += - get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op); + get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1); + kernel_source += + get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op); kernel_source += get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op); kernel_source += get_template_definition( @@ -61,9 +63,9 @@ void append_binary_kernels( std::string& kernel_source) { const std::array, 10> kernel_types = {{ {"ss", "binary_ss"}, - {"vs", "binary_vs"}, - {"sv", "binary_sv"}, - {"vv", "binary_vv"}, + {"vsn", "binary_vs"}, + {"svn", "binary_sv"}, + {"vvn", "binary_vv"}, {"vs2", "binary_vs2"}, {"sv2", "binary_sv2"}, {"vv2", "binary_vv2"}, @@ -78,6 +80,12 @@ void append_binary_kernels( kernel_source += get_template_definition(name + "_" + lib_name, func, in_t, out_t, op); } + kernel_source += get_template_definition( + "vs_" + lib_name, "binary_vs", in_t, out_t, op, 1); + kernel_source += get_template_definition( + "sv_" + lib_name, "binary_sv", in_t, out_t, op, 1); + kernel_source += get_template_definition( + "vv_" + lib_name, "binary_vv", in_t, out_t, op, 1); kernel_source += get_template_definition( "g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int"); kernel_source += get_template_definition( @@ -134,7 +142,7 @@ MTL::ComputePipelineState* 
get_ternary_kernel( std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::ternary_ops(), metal::ternary()); const std::array, 5> kernel_types = {{ - {"v", "ternary_v"}, + {"vn", "ternary_v"}, {"v2", "ternary_v2"}, {"g1large", "ternary_g_nd1"}, {"g2large", "ternary_g_nd2"}, @@ -144,6 +152,8 @@ MTL::ComputePipelineState* get_ternary_kernel( kernel_source += get_template_definition(name + "_" + lib_name, func, t_str, op); } + kernel_source += + get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1); kernel_source += get_template_definition( "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int"); kernel_source += get_template_definition( diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index ffc33ad82..f1df88535 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -17,8 +17,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - c[index + i] = Op()(a[0], b[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } } } @@ -30,8 +36,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - c[index + i] = Op()(a[index + i], b[0]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } } } @@ -43,8 +55,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - c[index + i] = Op()(a[index + i], b[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + 
i < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } } } @@ -57,8 +75,14 @@ template ::n> uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - c[offset + i] = Op()(a[0], b[offset + i]); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } } } @@ -71,8 +95,14 @@ template ::n> uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - c[offset + i] = Op()(a[offset + i], b[0]); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } } } @@ -85,8 +115,14 @@ template ::n> uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - c[offset + i] = Op()(a[offset + i], b[offset + i]); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } } } diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 1d555fefa..17ed13c57 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -9,11 +9,16 @@ #include "mlx/backend/metal/kernels/binary_ops.h" 
#include "mlx/backend/metal/kernels/binary.h" -#define instantiate_binary_all(op, tname, itype, otype) \ +#define instantiate_binary_work_per_thread(op, tname, itype, otype) \ + instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) \ + +#define instantiate_binary_base(op, tname, itype, otype) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \ + instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ @@ -26,15 +31,19 @@ instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \ instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) -#define instantiate_binary_integer(op) \ - instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ - instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ - instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ - instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ - instantiate_binary_all(op, int8, int8_t, int8_t) \ - instantiate_binary_all(op, int16, int16_t, int16_t) \ - instantiate_binary_all(op, int32, int32_t, int32_t) \ - instantiate_binary_all(op, int64, int64_t, int64_t) +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_binary_base(op, tname, itype, otype) \ + instantiate_binary_work_per_thread(op, tname, itype, otype) + 
+#define instantiate_binary_integer(op) \ + instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ + instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ + instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ + instantiate_binary_base(op, uint64, uint64_t, uint64_t) \ + instantiate_binary_all(op, int8, int8_t, int8_t) \ + instantiate_binary_all(op, int16, int16_t, int16_t) \ + instantiate_binary_all(op, int32, int32_t, int32_t) \ + instantiate_binary_base(op, int64, int64_t, int64_t) #define instantiate_binary_float(op) \ instantiate_binary_all(op, float16, half, half) \ @@ -44,7 +53,7 @@ #define instantiate_binary_types(op) \ instantiate_binary_all(op, bool_, bool, bool) \ instantiate_binary_integer(op) \ - instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ + instantiate_binary_base(op, complex64, complex64_t, complex64_t)\ instantiate_binary_float(op) #define instantiate_binary_types_bool(op) \ @@ -52,15 +61,15 @@ instantiate_binary_all(op, uint8, uint8_t, bool) \ instantiate_binary_all(op, uint16, uint16_t, bool) \ instantiate_binary_all(op, uint32, uint32_t, bool) \ - instantiate_binary_all(op, uint64, uint64_t, bool) \ + instantiate_binary_base(op, uint64, uint64_t, bool) \ instantiate_binary_all(op, int8, int8_t, bool) \ instantiate_binary_all(op, int16, int16_t, bool) \ instantiate_binary_all(op, int32, int32_t, bool) \ - instantiate_binary_all(op, int64, int64_t, bool) \ + instantiate_binary_base(op, int64, int64_t, bool) \ instantiate_binary_all(op, float16, half, bool) \ instantiate_binary_all(op, float32, float, bool) \ instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \ - instantiate_binary_all(op, complex64, complex64_t, bool) + instantiate_binary_base(op, complex64, complex64_t, bool) instantiate_binary_types(Add) instantiate_binary_types(Divide) @@ -71,7 +80,7 @@ instantiate_binary_types_bool(Less) instantiate_binary_types_bool(LessEqual) instantiate_binary_types_bool(NotEqual) instantiate_binary_float(LogAddExp) 
-instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t) +instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t) instantiate_binary_types(Maximum) instantiate_binary_types(Minimum) instantiate_binary_types(Multiply) @@ -84,7 +93,7 @@ instantiate_binary_float(ArcTan2) instantiate_binary_all(NaNEqual, float16, half, bool) instantiate_binary_all(NaNEqual, float32, float, bool) instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool) -instantiate_binary_all(NaNEqual, complex64, complex64_t, bool) +instantiate_binary_base(NaNEqual, complex64, complex64_t, bool) instantiate_binary_all(LogicalOr, bool_, bool, bool) instantiate_binary_all(LogicalAnd, bool_, bool, bool) diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h index e261d33c4..4455e4ca9 100644 --- a/mlx/backend/metal/kernels/binary_two.h +++ b/mlx/backend/metal/kernels/binary_two.h @@ -21,10 +21,18 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - auto out = Op()(a[0], b[index + i]); - c[index + i] = out[0]; - d[index + i] = out[1]; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } } @@ -37,10 +45,18 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - auto out = Op()(a[index + i], b[0]); - c[index + i] = out[0]; - d[index + i] = out[1]; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[0]); + 
c[index + i] = out[0]; + d[index + i] = out[1]; + } } } @@ -53,10 +69,18 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - auto out = Op()(a[index + i], b[index + i]); - c[index + i] = out[0]; - d[index + i] = out[1]; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } } @@ -69,11 +93,19 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - auto out = Op()(a[0], b[offset + i]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } } @@ -86,11 +118,19 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - auto out = Op()(a[offset + i], b[0]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < 
N; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } } @@ -103,11 +143,19 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - auto out = Op()(a[offset + i], b[offset + i]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } } diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index 984a28320..c7d3ecdf0 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -7,11 +7,16 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_two.h" -#define instantiate_binary_all(op, tname, itype, otype) \ +#define instantiate_binary_work_per_thread(op, tname, itype, otype) \ + instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) + +#define instantiate_binary_base(op, tname, itype, otype) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \ + 
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ @@ -24,22 +29,26 @@ instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_binary_base(op, tname, itype, otype) \ + instantiate_binary_work_per_thread(op, tname, itype, otype) + #define instantiate_binary_float(op) \ instantiate_binary_all(op, float16, half, half) \ instantiate_binary_all(op, float32, float, float) \ instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t) -#define instantiate_binary_types(op) \ - instantiate_binary_all(op, bool_, bool, bool) \ - instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ - instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ - instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ - instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ - instantiate_binary_all(op, int8, int8_t, int8_t) \ - instantiate_binary_all(op, int16, int16_t, int16_t) \ - instantiate_binary_all(op, int32, int32_t, int32_t) \ - instantiate_binary_all(op, int64, int64_t, int64_t) \ - instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ +#define instantiate_binary_types(op) \ + instantiate_binary_all(op, bool_, bool, bool) \ + instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ + instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ + instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ + instantiate_binary_base(op, uint64, uint64_t, uint64_t) \ + instantiate_binary_all(op, int8, int8_t, int8_t) \ + instantiate_binary_all(op, int16, int16_t, int16_t) \ + instantiate_binary_all(op, int32, int32_t, 
int32_t) \ + instantiate_binary_base(op, int64, int64_t, int64_t) \ + instantiate_binary_base(op, complex64, complex64_t, complex64_t) \ instantiate_binary_float(op) instantiate_binary_types(DivMod) // clang-format on diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index 5251dc7e9..570f5e4d6 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -9,8 +9,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } } } @@ -23,9 +29,15 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } } } diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal index cceb53061..6da258b6f 100644 --- a/mlx/backend/metal/kernels/ternary.metal +++ b/mlx/backend/metal/kernels/ternary.metal @@ -8,8 +8,8 @@ #include "mlx/backend/metal/kernels/ternary_ops.h" #include "mlx/backend/metal/kernels/ternary.h" -#define instantiate_ternary_all(op, tname, type) \ - instantiate_kernel("v_" #op #tname, 
ternary_v, type, op) \ +#define instantiate_ternary_base(op, tname, type) \ + instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \ instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \ instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \ @@ -20,19 +20,23 @@ instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \ instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \ +#define instantiate_ternary_all(op, tname, type) \ + instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \ + instantiate_ternary_base(op, tname, type) + #define instantiate_ternary_types(op) \ instantiate_ternary_all(op, bool_, bool) \ instantiate_ternary_all(op, uint8, uint8_t) \ instantiate_ternary_all(op, uint16, uint16_t) \ instantiate_ternary_all(op, uint32, uint32_t) \ - instantiate_ternary_all(op, uint64, uint64_t) \ + instantiate_ternary_base(op, uint64, uint64_t) \ instantiate_ternary_all(op, int8, int8_t) \ instantiate_ternary_all(op, int16, int16_t) \ instantiate_ternary_all(op, int32, int32_t) \ - instantiate_ternary_all(op, int64, int64_t) \ + instantiate_ternary_base(op, int64, int64_t) \ instantiate_ternary_all(op, float16, half) \ instantiate_ternary_all(op, float32, float) \ instantiate_ternary_all(op, bfloat16, bfloat16_t) \ - instantiate_ternary_all(op, complex64, complex64_t) // clang-format on + instantiate_ternary_base(op, complex64, complex64_t) // clang-format on instantiate_ternary_types(Select) diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index b5eaab2e9..649ba7f2c 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -7,8 +7,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - out[index + i] = Op()(in[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i 
< size; ++i) { + out[index + i] = Op()(in[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[index + i] = Op()(in[index + i]); + } } } @@ -19,9 +25,15 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - out[offset + i] = Op()(in[offset + i]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + out[offset + i] = Op()(in[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[offset + i] = Op()(in[offset + i]); + } } } diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index afced7eb7..160ef4af1 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -5,31 +5,41 @@ #include "mlx/backend/metal/kernels/unary_ops.h" #include "mlx/backend/metal/kernels/unary.h" -#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \ - instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \ - instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \ - instantiate_kernel( \ - "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \ - instantiate_kernel( \ +#define instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) \ + instantiate_kernel("vn_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) + +#define instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \ + instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op, 1) \ + instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \ + instantiate_kernel( \ + "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \ + instantiate_kernel( \ 
"gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4) +#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \ + instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \ + instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) + #define instantiate_unary_all_same(op, tname, type) \ instantiate_unary_all(op, tname, tname, type, type) +#define instantiate_unary_base_same(op, tname, type) \ + instantiate_unary_base(op, tname, tname, type, type) + #define instantiate_unary_float(op) \ instantiate_unary_all_same(op, float16, half) \ instantiate_unary_all_same(op, float32, float) \ instantiate_unary_all_same(op, bfloat16, bfloat16_t) -#define instantiate_unary_int(op) \ - instantiate_unary_all_same(op, uint8, uint8_t) \ - instantiate_unary_all_same(op, uint16, uint16_t) \ - instantiate_unary_all_same(op, uint32, uint32_t) \ - instantiate_unary_all_same(op, uint64, uint64_t) \ - instantiate_unary_all_same(op, int8, int8_t) \ - instantiate_unary_all_same(op, int16, int16_t) \ - instantiate_unary_all_same(op, int32, int32_t) \ - instantiate_unary_all_same(op, int64, int64_t) +#define instantiate_unary_int(op) \ + instantiate_unary_all_same(op, uint8, uint8_t) \ + instantiate_unary_all_same(op, uint16, uint16_t) \ + instantiate_unary_all_same(op, uint32, uint32_t) \ + instantiate_unary_base_same(op, uint64, uint64_t) \ + instantiate_unary_all_same(op, int8, int8_t) \ + instantiate_unary_all_same(op, int16, int16_t) \ + instantiate_unary_all_same(op, int32, int32_t) \ + instantiate_unary_base_same(op, int64, int64_t) #define instantiate_unary_types(op) \ instantiate_unary_all_same(op, bool_, bool) \ @@ -68,29 +78,29 @@ instantiate_unary_float(Tanh) instantiate_unary_float(Round) instantiate_unary_int(BitwiseInvert) -instantiate_unary_all_same(Abs, complex64, complex64_t) -instantiate_unary_all_same(ArcCos, complex64, complex64_t) -instantiate_unary_all_same(ArcSin, complex64, complex64_t) 
-instantiate_unary_all_same(ArcTan, complex64, complex64_t) -instantiate_unary_all_same(Conjugate, complex64, complex64_t) -instantiate_unary_all_same(Cos, complex64, complex64_t) -instantiate_unary_all_same(Cosh, complex64, complex64_t) -instantiate_unary_all_same(Exp, complex64, complex64_t) -instantiate_unary_all_same(Log, complex64, complex64_t) -instantiate_unary_all_same(Log1p, complex64, complex64_t) -instantiate_unary_all_same(Log2, complex64, complex64_t) -instantiate_unary_all_same(Log10, complex64, complex64_t) -instantiate_unary_all_same(Negative, complex64, complex64_t) -instantiate_unary_all_same(Sign, complex64, complex64_t) -instantiate_unary_all_same(Sin, complex64, complex64_t) -instantiate_unary_all_same(Sinh, complex64, complex64_t) -instantiate_unary_all_same(Square, complex64, complex64_t) -instantiate_unary_all_same(Sqrt, complex64, complex64_t) -instantiate_unary_all_same(Rsqrt, complex64, complex64_t) -instantiate_unary_all_same(Tan, complex64, complex64_t) -instantiate_unary_all_same(Tanh, complex64, complex64_t) -instantiate_unary_all_same(Round, complex64, complex64_t) -instantiate_unary_all(Real, complex64, float32, complex64_t, float) -instantiate_unary_all(Imag, complex64, float32, complex64_t, float) +instantiate_unary_base_same(Abs, complex64, complex64_t) +instantiate_unary_base_same(ArcCos, complex64, complex64_t) +instantiate_unary_base_same(ArcSin, complex64, complex64_t) +instantiate_unary_base_same(ArcTan, complex64, complex64_t) +instantiate_unary_base_same(Conjugate, complex64, complex64_t) +instantiate_unary_base_same(Cos, complex64, complex64_t) +instantiate_unary_base_same(Cosh, complex64, complex64_t) +instantiate_unary_base_same(Exp, complex64, complex64_t) +instantiate_unary_base_same(Log, complex64, complex64_t) +instantiate_unary_base_same(Log1p, complex64, complex64_t) +instantiate_unary_base_same(Log2, complex64, complex64_t) +instantiate_unary_base_same(Log10, complex64, complex64_t) 
+instantiate_unary_base_same(Negative, complex64, complex64_t) +instantiate_unary_base_same(Sign, complex64, complex64_t) +instantiate_unary_base_same(Sin, complex64, complex64_t) +instantiate_unary_base_same(Sinh, complex64, complex64_t) +instantiate_unary_base_same(Square, complex64, complex64_t) +instantiate_unary_base_same(Sqrt, complex64, complex64_t) +instantiate_unary_base_same(Rsqrt, complex64, complex64_t) +instantiate_unary_base_same(Tan, complex64, complex64_t) +instantiate_unary_base_same(Tanh, complex64, complex64_t) +instantiate_unary_base_same(Round, complex64, complex64_t) +instantiate_unary_base(Real, complex64, float32, complex64_t, float) +instantiate_unary_base(Imag, complex64, float32, complex64_t, float) instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 0b821151e..22f2a1985 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -45,7 +45,7 @@ void ternary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > INT32_MAX; - work_per_thread = get_work_per_thread(b.dtype()); + work_per_thread = get_work_per_thread(b.dtype(), out.data_size()); } std::string kernel_name; if (topt == TernaryOpType::General) { @@ -60,6 +60,8 @@ void ternary_op_gpu_inplace( } } else if (large) { kernel_name = "v2"; + } else if (work_per_thread > 1) { + kernel_name = "vn"; } else { kernel_name = "v"; } diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index 368e693a9..850c17376 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -43,8 +43,8 @@ void unary_op_gpu_inplace( int work_per_thread; std::string kernel_name; if (contig) { - work_per_thread = get_work_per_thread(in.dtype()); - kernel_name = (large ? "v2" : "v"); + work_per_thread = get_work_per_thread(in.dtype(), in.data_size()); + kernel_name = (large ? "v2" : (work_per_thread > 1 ? 
"vn" : "v")); } else { work_per_thread = large ? 4 : 1; kernel_name = "gn" + std::to_string(work_per_thread); diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 576fb9107..a491521a0 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -72,6 +72,10 @@ void concatenate(std::string& acc, T first, Args... args) { inline int get_work_per_thread(Dtype dtype) { return std::max(1, 8 / dtype.size()); } +inline int get_work_per_thread(Dtype dtype, size_t size) { + constexpr size_t wpt_threshold = 1 << 16; + return size < wpt_threshold ? 1 : std::max(1, 8 / dtype.size()); +} inline size_t ceildiv(size_t n, size_t m) { return (n + m - 1) / m;