Improve Metal elementwise kernels (#2247)

* improve Metal elementwise kernels

* compile and copy

* fix JIT
This commit is contained in:
Awni Hannun
2025-06-06 11:37:40 -07:00
committed by GitHub
parent a5ac9244c4
commit c6a20b427a
17 changed files with 412 additions and 174 deletions

View File

@@ -41,7 +41,11 @@ MTL::ComputePipelineState* get_unary_kernel(
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::unary_ops(), metal::unary());
kernel_source +=
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1);
if (get_work_per_thread(in_type) > 1) {
kernel_source +=
get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op);
}
kernel_source +=
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source += get_template_definition(
@@ -59,11 +63,8 @@ void append_binary_kernels(
Dtype out_type,
const std::string op,
std::string& kernel_source) {
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
{"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
{"vv", "binary_vv"},
{"vs2", "binary_vs2"},
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
@@ -78,6 +79,22 @@ void append_binary_kernels(
kernel_source +=
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
}
kernel_source += get_template_definition(
"vs_" + lib_name, "binary_vs", in_t, out_t, op, 1);
kernel_source += get_template_definition(
"sv_" + lib_name, "binary_sv", in_t, out_t, op, 1);
kernel_source += get_template_definition(
"vv_" + lib_name, "binary_vv", in_t, out_t, op, 1);
if (get_work_per_thread(in_type) > 1) {
kernel_source += get_template_definition(
"vsn_" + lib_name, "binary_vs", in_t, out_t, op);
kernel_source += get_template_definition(
"svn_" + lib_name, "binary_sv", in_t, out_t, op);
kernel_source += get_template_definition(
"vvn_" + lib_name, "binary_vv", in_t, out_t, op);
}
kernel_source += get_template_definition(
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
kernel_source += get_template_definition(
@@ -133,8 +150,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
auto t_str = get_type_string(type);
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"},
const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
{"v2", "ternary_v2"},
{"g1large", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"},
@@ -144,6 +160,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
kernel_source +=
get_template_definition(name + "_" + lib_name, func, t_str, op);
}
if (get_work_per_thread(type) > 1) {
kernel_source +=
get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
}
kernel_source +=
get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
kernel_source += get_template_definition(
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
kernel_source += get_template_definition(
@@ -170,15 +193,22 @@ MTL::ComputePipelineState* get_copy_kernel(
kernel_source += metal::copy();
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source +=
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
kernel_source += get_template_definition(
"s_" + lib_name, "copy_s", in_type, out_type, 1);
kernel_source +=
get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type);
kernel_source +=
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
kernel_source += get_template_definition(
"v_" + lib_name, "copy_v", in_type, out_type, 1);
kernel_source +=
get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type);
if (get_work_per_thread(out.dtype()) > 1) {
kernel_source += get_template_definition(
"sn_" + lib_name, "copy_s", in_type, out_type);
kernel_source += get_template_definition(
"vn_" + lib_name, "copy_v", in_type, out_type);
}
kernel_source += get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
kernel_source += get_template_definition(