mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -40,18 +40,21 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
auto in_t = get_type_string(in_type);
|
||||
auto out_t = get_type_string(out_type);
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
|
||||
kernel_source << get_template_definition(
|
||||
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
|
||||
"v_" + lib_name, "unary_v", in_t, out_t, op);
|
||||
kernel_source << get_template_definition(
|
||||
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
|
||||
"v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source << get_template_definition(
|
||||
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
|
||||
"gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
|
||||
Reference in New Issue
Block a user