mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Faster metal compiled kernels + some fixes (#1486)
* bump mac tests to use py39 * work per thread for compiled kernels * fix for large arrays * fix
This commit is contained in:
@@ -50,8 +50,6 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
|
||||
kernel_source << get_template_definition(
|
||||
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
|
||||
kernel_source << get_template_definition(
|
||||
"g_" + lib_name, "unary_g", get_type_string(out_type), op);
|
||||
kernel_source << get_template_definition(
|
||||
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
|
||||
return kernel_source.str();
|
||||
@@ -65,7 +63,7 @@ void add_binary_kernels(
|
||||
Dtype out_type,
|
||||
const std::string op,
|
||||
std::ostringstream& kernel_source) {
|
||||
const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
|
||||
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
|
||||
{"ss", "binary_ss"},
|
||||
{"vs", "binary_vs"},
|
||||
{"sv", "binary_sv"},
|
||||
@@ -76,7 +74,6 @@ void add_binary_kernels(
|
||||
{"g1", "binary_g_nd1"},
|
||||
{"g2", "binary_g_nd2"},
|
||||
{"g3", "binary_g_nd3"},
|
||||
{"gn", "binary_g"},
|
||||
}};
|
||||
for (auto& [name, func] : kernel_types) {
|
||||
std::string template_def;
|
||||
@@ -138,10 +135,9 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
|
||||
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
|
||||
{"v", "ternary_v"},
|
||||
{"v2", "ternary_v2"},
|
||||
{"g", "ternary_g"},
|
||||
{"g1", "ternary_g_nd1"},
|
||||
{"g2", "ternary_g_nd2"},
|
||||
{"g3", "ternary_g_nd3"},
|
||||
@@ -170,29 +166,27 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
std::ostringstream kernel_source;
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
kernel_source
|
||||
<< metal::utils() << metal::copy()
|
||||
<< get_template_definition("s_" + lib_name, "copy_s", in_type, out_type)
|
||||
<< get_template_definition("v_" + lib_name, "copy_v", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
|
||||
<< get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
|
||||
<< get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg_" + lib_name, "copy_gg", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
|
||||
kernel_source << metal::utils() << metal::copy()
|
||||
<< get_template_definition(
|
||||
"s_" + lib_name, "copy_s", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"v_" + lib_name, "copy_v", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
|
||||
<< get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
|
||||
<< get_template_definition(
|
||||
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
|
||||
Reference in New Issue
Block a user