Faster triu, tril, where with scalar (#2644)

Awni Hannun
2025-10-02 12:21:27 -07:00
committed by GitHub
parent e88f2d4a8e
commit b0cc71ae71
7 changed files with 101 additions and 42 deletions
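For context, a minimal sketch of the calls this fast path targets: where with a scalar branch, and triu/tril, which reduce to a select against a scalar zero. The mlx::core usage below (where, greater, triu, ones, eval) is an assumption about the public C++ ops API and is illustrative only.

// Sketch only: calls expected to hit the new scalar ternary kernels.
// Assumes the public mlx::core C++ ops (where, greater, triu, ones, eval).
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto x = ones({1024, 1024});
  // One branch of the select is a scalar array (the "vs"/"sv" variants below).
  auto a = where(greater(x, array(0.5f)), x, array(0.0f));
  // triu/tril mask against a scalar zero, so they take the same path.
  auto b = triu(x);
  eval({a, b});
}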

View File

@@ -144,8 +144,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
     auto t_str = get_type_string(type);
     std::string kernel_source = metal::utils();
     concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
-    const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
-        {"v2", "ternary_v2"},
+    const std::array<std::pair<std::string, std::string>, 3> kernel_types = {{
         {"g1large", "ternary_g_nd1"},
         {"g2large", "ternary_g_nd2"},
         {"g3large", "ternary_g_nd3"},
@@ -154,13 +153,29 @@ MTL::ComputePipelineState* get_ternary_kernel(
       kernel_source +=
           get_template_definition(name + "_" + lib_name, func, t_str, op);
     }
+    kernel_source += get_template_definition(
+        "v2_" + lib_name, "ternary_v2", t_str, op, false, false);
+    kernel_source += get_template_definition(
+        "sv2_" + lib_name, "ternary_v2", t_str, op, true, false);
+    kernel_source += get_template_definition(
+        "vs2_" + lib_name, "ternary_v2", t_str, op, false, true);
     if (get_work_per_thread(type) > 1) {
-      kernel_source +=
-          get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
+      kernel_source += get_template_definition(
+          "vn_" + lib_name, "ternary_v", t_str, op, false, false);
+      kernel_source += get_template_definition(
+          "svn_" + lib_name, "ternary_v", t_str, op, true, false);
+      kernel_source += get_template_definition(
+          "vsn_" + lib_name, "ternary_v", t_str, op, false, true);
     }
-    kernel_source +=
-        get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
+    kernel_source += get_template_definition(
+        "v_" + lib_name, "ternary_v", t_str, op, false, false, 1);
+    kernel_source += get_template_definition(
+        "sv_" + lib_name, "ternary_v", t_str, op, true, false, 1);
+    kernel_source += get_template_definition(
+        "vs_" + lib_name, "ternary_v", t_str, op, false, true, 1);
     kernel_source += get_template_definition(
         "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
     kernel_source += get_template_definition(

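The hunk above assembles the JIT'd Metal source per variant. Below is a rough, hypothetical stand-in for get_template_definition, only to show how the new boolean arguments become template arguments in the emitted instantiations; names such as "sv_where_float32" are illustrative, not the exact strings the library produces.

// Hypothetical helper mirroring the shape of the generated source; not the
// repo's get_template_definition.
#include <iostream>
#include <string>

std::string make_definition(
    const std::string& name, // host name, e.g. "sv_where_float32"
    const std::string& func, // kernel template, e.g. "ternary_v"
    const std::string& type, // element type, e.g. "float"
    const std::string& op,   // op functor, e.g. "Select"
    bool b_scalar,
    bool c_scalar) {
  std::string targs = type + ", " + op + ", " + (b_scalar ? "true" : "false") +
      ", " + (c_scalar ? "true" : "false");
  return "template [[host_name(\"" + name + "\")]] [[kernel]] decltype(" +
      func + "<" + targs + ">) " + func + "<" + targs + ">;\n";
}

int main() {
  // "sv" = b is the broadcast scalar, "vs" = c is, "v" = neither.
  std::cout << make_definition(
      "sv_where_float32", "ternary_v", "float", "Select", true, false);
}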
View File

@@ -1,6 +1,11 @@
 // Copyright © 2024 Apple Inc.

-template <typename T, typename Op, int N = WorkPerThread<T>::n>
+template <
+    typename T,
+    typename Op,
+    bool BSCALAR,
+    bool CSCALAR,
+    int N = WorkPerThread<T>::n>
 [[kernel]] void ternary_v(
     device const bool* a,
     device const T* b,
@@ -11,16 +16,25 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
   index *= N;
   if (N > 1 && index + N > size) {
     for (int i = 0; index + i < size; ++i) {
-      d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
+      auto bidx = BSCALAR ? 0 : index + i;
+      auto cidx = CSCALAR ? 0 : index + i;
+      d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);
     }
   } else {
     for (int i = 0; i < N; ++i) {
-      d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
+      auto bidx = BSCALAR ? 0 : index + i;
+      auto cidx = CSCALAR ? 0 : index + i;
+      d[index + i] = Op()(a[index + i], b[bidx], c[cidx]);
     }
   }
 }

-template <typename T, typename Op, int N = WorkPerThread<T>::n>
+template <
+    typename T,
+    typename Op,
+    bool BSCALAR,
+    bool CSCALAR,
+    int N = WorkPerThread<T>::n>
 [[kernel]] void ternary_v2(
     device const bool* a,
     device const T* b,
@@ -32,11 +46,15 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
   int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
   if (N > 1 && offset + N > size) {
     for (int i = 0; offset + i < size; ++i) {
-      d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
+      auto bidx = BSCALAR ? 0 : offset + i;
+      auto cidx = CSCALAR ? 0 : offset + i;
+      d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);
     }
   } else {
     for (int i = 0; i < N; ++i) {
-      d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
+      auto bidx = BSCALAR ? 0 : offset + i;
+      auto cidx = CSCALAR ? 0 : offset + i;
+      d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]);
     }
   }
 }

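The same trick in plain C++, as a reference for what BSCALAR/CSCALAR do in the kernels above: a scalar operand is stored as a single element and every thread reads index 0, and since the flags are template parameters the selection costs nothing at runtime. Select here is a stand-in for the library's ternary op functor; all names in this sketch are illustrative.

// Host-side mimic of ternary_v's scalar handling.
#include <cstddef>

template <typename T, typename Op, bool BSCALAR, bool CSCALAR>
void ternary_v_ref(const bool* a, const T* b, const T* c, T* d, std::size_t size) {
  for (std::size_t i = 0; i < size; ++i) {
    std::size_t bidx = BSCALAR ? 0 : i; // scalar b: one element, reused by every i
    std::size_t cidx = CSCALAR ? 0 : i; // scalar c: likewise
    d[i] = Op()(a[i], b[bidx], c[cidx]);
  }
}

struct Select {
  template <typename T>
  T operator()(bool cond, T x, T y) const {
    return cond ? x : y;
  }
};

// Usage for the "vs" case (vector b, scalar c), e.g. where(cond, x, 0):
//   float zero = 0.0f;
//   ternary_v_ref<float, Select, false, true>(cond, x, &zero, out, n);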
View File

@@ -9,8 +9,12 @@
 #include "mlx/backend/metal/kernels/ternary.h"

 #define instantiate_ternary_base(op, tname, type) \
-  instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \
-  instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
+  instantiate_kernel("v_" #op #tname, ternary_v, type, op, false, false, 1) \
+  instantiate_kernel("v2_" #op #tname, ternary_v2, type, op, false, false) \
+  instantiate_kernel("vs_" #op #tname, ternary_v, type, op, false, true, 1) \
+  instantiate_kernel("vs2_" #op #tname, ternary_v2, type, op, false, true) \
+  instantiate_kernel("sv_" #op #tname, ternary_v, type, op, true, false, 1) \
+  instantiate_kernel("sv2_" #op #tname, ternary_v2, type, op, true, false) \
   instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
   instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
   instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \
@@ -21,7 +25,9 @@
   instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \

 #define instantiate_ternary_all(op, tname, type) \
-  instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \
+  instantiate_kernel("vn_" #op #tname, ternary_v, type, op, false, false) \
+  instantiate_kernel("vsn_" #op #tname, ternary_v, type, op, false, true) \
+  instantiate_kernel("svn_" #op #tname, ternary_v, type, op, true, false) \
   instantiate_ternary_base(op, tname, type)

 #define instantiate_ternary_types(op) \

View File

@@ -58,12 +58,19 @@ void ternary_op_gpu_inplace(
     if (large) {
       kernel_name += "large";
     }
-  } else if (large) {
-    kernel_name = "v2";
-  } else if (work_per_thread > 1) {
-    kernel_name = "vn";
   } else {
-    kernel_name = "v";
+    if (topt == TernaryOpType::VectorScalarVector) {
+      kernel_name = "sv";
+    } else if (topt == TernaryOpType::VectorVectorScalar) {
+      kernel_name = "vs";
+    } else {
+      kernel_name = "v";
+    }
+    if (large) {
+      kernel_name += "2";
+    } else if (work_per_thread > 1) {
+      kernel_name += "n";
+    }
   }
   concatenate(kernel_name, "_", op, type_to_name(b));
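A standalone mirror of the kernel-name selection above, to make the naming scheme explicit. The enum values beyond the two named in the diff, and the example output, are assumptions for illustration.

// Mirrors the dispatch logic above for the contiguous case; illustrative only.
#include <iostream>
#include <string>

enum class TernaryOpType {
  ScalarScalarScalar,
  VectorVectorVector,
  VectorScalarVector, // b is a broadcast scalar
  VectorVectorScalar, // c is a broadcast scalar
  General
};

std::string contiguous_kernel_name(
    TernaryOpType topt, bool large, int work_per_thread) {
  std::string name;
  if (topt == TernaryOpType::VectorScalarVector) {
    name = "sv";
  } else if (topt == TernaryOpType::VectorVectorScalar) {
    name = "vs";
  } else {
    name = "v";
  }
  if (large) {
    name += "2"; // ternary_v2: 2-D grid with 64-bit offsets
  } else if (work_per_thread > 1) {
    name += "n"; // ternary_v with N > 1 elements per thread
  }
  return name;
}

int main() {
  // e.g. a large where() with a scalar b would dispatch an "sv2_..." kernel.
  std::cout << contiguous_kernel_name(TernaryOpType::VectorScalarVector, true, 4)
            << "\n";
}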