Faster Metal unary and binary for general case (#1431)

* faster unary and binary for general case

* update ternary + JIT fix

* fix JIT

* unary work per thread
This commit is contained in:
Awni Hannun
2024-09-25 12:07:43 -07:00
committed by GitHub
parent afc9c0ec1b
commit 4f9f9ebb6f
12 changed files with 183 additions and 93 deletions

View File

@@ -42,18 +42,19 @@ MTL::ComputePipelineState* get_unary_kernel(
const std::string& kernel_name,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(1);
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto u_def = get_template_definition(
"v" + lib_name, "unary_v", get_type_string(out_type), op);
auto u2_def = get_template_definition(
"v2" + lib_name, "unary_v2", get_type_string(out_type), op);
auto g_def = get_template_definition(
"g" + lib_name, "unary_g", get_type_string(out_type), op);
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
<< u_def << u2_def << g_def;
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
kernel_source << get_template_definition(
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
kernel_source << get_template_definition(
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
kernel_source << get_template_definition(
"g_" + lib_name, "unary_g", get_type_string(out_type), op);
kernel_source << get_template_definition(
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -81,13 +82,20 @@ void add_binary_kernels(
for (auto& [name, func] : kernel_types) {
std::string template_def;
template_def = get_template_definition(
name + lib_name,
name + "_" + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
kernel_source << template_def;
}
kernel_source << get_template_definition(
"gn4_" + lib_name,
"binary_g",
get_type_string(in_type),
get_type_string(out_type),
op,
4);
}
MTL::ComputePipelineState* get_binary_kernel(
@@ -96,7 +104,7 @@ MTL::ComputePipelineState* get_binary_kernel(
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(2);
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
@@ -113,7 +121,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(2);
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
@@ -149,6 +157,8 @@ MTL::ComputePipelineState* get_ternary_kernel(
name + "_" + lib_name, func, get_type_string(type), op);
kernel_source << template_def;
}
kernel_source << get_template_definition(
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);