mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
fix jit
This commit is contained in:
parent
ba8748b12e
commit
a4a4b46b8d
@ -42,8 +42,10 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
concatenate(kernel_source, metal::unary_ops(), metal::unary());
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1);
|
||||
kernel_source +=
|
||||
get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op);
|
||||
if (get_work_per_thread(in_type) > 1) {
|
||||
kernel_source +=
|
||||
get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op);
|
||||
}
|
||||
kernel_source +=
|
||||
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source += get_template_definition(
|
||||
@ -61,11 +63,8 @@ void append_binary_kernels(
|
||||
Dtype out_type,
|
||||
const std::string op,
|
||||
std::string& kernel_source) {
|
||||
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
|
||||
const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
|
||||
{"ss", "binary_ss"},
|
||||
{"vsn", "binary_vs"},
|
||||
{"svn", "binary_sv"},
|
||||
{"vvn", "binary_vv"},
|
||||
{"vs2", "binary_vs2"},
|
||||
{"sv2", "binary_sv2"},
|
||||
{"vv2", "binary_vv2"},
|
||||
@ -85,7 +84,17 @@ void append_binary_kernels(
|
||||
kernel_source += get_template_definition(
|
||||
"sv_" + lib_name, "binary_sv", in_t, out_t, op, 1);
|
||||
kernel_source += get_template_definition(
|
||||
"sv_" + lib_name, "binary_sv", in_t, out_t, op, 1);
|
||||
"vv_" + lib_name, "binary_vv", in_t, out_t, op, 1);
|
||||
|
||||
if (get_work_per_thread(in_type) > 1) {
|
||||
kernel_source += get_template_definition(
|
||||
"vsn_" + lib_name, "binary_vs", in_t, out_t, op);
|
||||
kernel_source += get_template_definition(
|
||||
"svn_" + lib_name, "binary_sv", in_t, out_t, op);
|
||||
kernel_source += get_template_definition(
|
||||
"vvn_" + lib_name, "binary_vv", in_t, out_t, op);
|
||||
}
|
||||
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
@ -141,8 +150,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
auto t_str = get_type_string(type);
|
||||
std::string kernel_source = metal::utils();
|
||||
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
|
||||
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
|
||||
{"vn", "ternary_v"},
|
||||
const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
|
||||
{"v2", "ternary_v2"},
|
||||
{"g1large", "ternary_g_nd1"},
|
||||
{"g2large", "ternary_g_nd2"},
|
||||
@ -152,6 +160,11 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
kernel_source +=
|
||||
get_template_definition(name + "_" + lib_name, func, t_str, op);
|
||||
}
|
||||
if (get_work_per_thread(type) > 1) {
|
||||
kernel_source +=
|
||||
get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
|
||||
}
|
||||
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
|
||||
kernel_source += get_template_definition(
|
||||
@ -182,18 +195,20 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
kernel_source += get_template_definition(
|
||||
"s_" + lib_name, "copy_s", in_type, out_type, 1);
|
||||
kernel_source +=
|
||||
get_template_definition("sn_" + lib_name, "copy_s", in_type, out_type);
|
||||
kernel_source +=
|
||||
get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"v_" + lib_name, "copy_v", in_type, out_type, 1);
|
||||
kernel_source +=
|
||||
get_template_definition("vn_" + lib_name, "copy_v", in_type, out_type);
|
||||
|
||||
kernel_source +=
|
||||
get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type);
|
||||
|
||||
if (get_work_per_thread(out.dtype()) > 1) {
|
||||
kernel_source += get_template_definition(
|
||||
"sn_" + lib_name, "copy_s", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"vn_" + lib_name, "copy_v", in_type, out_type);
|
||||
}
|
||||
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
|
Loading…
Reference in New Issue
Block a user