Optimization for general ND copies (#1421)

This commit is contained in:
Awni Hannun
2024-09-17 17:59:51 -07:00
committed by GitHub
parent 6af5ca35b2
commit 67b6bf530d
5 changed files with 62 additions and 23 deletions

View File

@@ -176,6 +176,8 @@ MTL::ComputePipelineState* get_copy_kernel(
<< get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
<< get_template_definition(
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
<< get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition(
@@ -183,7 +185,9 @@ MTL::ComputePipelineState* get_copy_kernel(
<< get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition(
"gg_" + lib_name, "copy_gg", in_type, out_type);
"gg_" + lib_name, "copy_gg", in_type, out_type)
<< get_template_definition(
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);