Faster metal compiled kernels + some fixes (#1486)

* bump mac tests to use py39

* work per thread for compiled kernels

* fixe for large arrays

* fix
This commit is contained in:
Awni Hannun
2024-10-14 12:45:38 -07:00
committed by GitHub
parent 0eef4febfd
commit 881615b072
12 changed files with 157 additions and 108 deletions

View File

@@ -50,8 +50,6 @@ MTL::ComputePipelineState* get_unary_kernel(
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
kernel_source << get_template_definition(
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
kernel_source << get_template_definition(
"g_" + lib_name, "unary_g", get_type_string(out_type), op);
kernel_source << get_template_definition(
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
return kernel_source.str();
@@ -65,7 +63,7 @@ void add_binary_kernels(
Dtype out_type,
const std::string op,
std::ostringstream& kernel_source) {
const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
{"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
@@ -76,7 +74,6 @@ void add_binary_kernels(
{"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
{"gn", "binary_g"},
}};
for (auto& [name, func] : kernel_types) {
std::string template_def;
@@ -138,10 +135,9 @@ MTL::ComputePipelineState* get_ternary_kernel(
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g", "ternary_g"},
{"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"},
@@ -170,29 +166,27 @@ MTL::ComputePipelineState* get_copy_kernel(
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source
<< metal::utils() << metal::copy()
<< get_template_definition("s_" + lib_name, "copy_s", in_type, out_type)
<< get_template_definition("v_" + lib_name, "copy_v", in_type, out_type)
<< get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
<< get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
<< get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
<< get_template_definition(
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
<< get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
<< get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition(
"gg_" + lib_name, "copy_gg", in_type, out_type)
<< get_template_definition(
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
kernel_source << metal::utils() << metal::copy()
<< get_template_definition(
"s_" + lib_name, "copy_s", in_type, out_type)
<< get_template_definition(
"v_" + lib_name, "copy_v", in_type, out_type)
<< get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
<< get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
<< get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition(
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
<< get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
<< get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition(
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);