mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Optimization for general ND copies (#1421)
This commit is contained in:
@@ -176,6 +176,8 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
<< get_template_definition(
|
||||
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
|
||||
<< get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
|
||||
+ << get_template_definition(
|
||||
+ "gn4_" + lib_name, "copy_g", in_type, out_type, 4)
|
||||
<< get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
@@ -183,7 +185,9 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
<< get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
- "gg_" + lib_name, "copy_gg", in_type, out_type);
|
||||
+ "gg_" + lib_name, "copy_gg", in_type, out_type)
|
||||
+ << get_template_definition(
|
||||
+ "ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
|
||||
Reference in New Issue
Block a user