mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Allow no copy negative strides in as_strided and slice (#1688)
* allow no copy negative strides in as_strided and slice * fix jit * fix jit
This commit is contained in:
@@ -52,7 +52,7 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
kernel_source +=
|
||||
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source += get_template_definition(
|
||||
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
|
||||
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
|
||||
return kernel_source;
|
||||
@@ -74,7 +74,7 @@ void append_binary_kernels(
|
||||
{"vs2", "binary_vs2"},
|
||||
{"sv2", "binary_sv2"},
|
||||
{"vv2", "binary_vv2"},
|
||||
{"g1", "binary_g_nd1"},
|
||||
{"g1large", "binary_g_nd1"},
|
||||
{"g2large", "binary_g_nd2"},
|
||||
{"g3large", "binary_g_nd3"},
|
||||
}};
|
||||
@@ -86,11 +86,13 @@ void append_binary_kernels(
|
||||
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
|
||||
}
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
|
||||
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
|
||||
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
|
||||
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
|
||||
}
|
||||
@@ -141,7 +143,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
|
||||
{"v", "ternary_v"},
|
||||
{"v2", "ternary_v2"},
|
||||
{"g1", "ternary_g_nd1"},
|
||||
{"g1large", "ternary_g_nd1"},
|
||||
{"g2large", "ternary_g_nd2"},
|
||||
{"g3large", "ternary_g_nd3"},
|
||||
}};
|
||||
@@ -150,11 +152,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
get_template_definition(name + "_" + lib_name, func, t_str, op);
|
||||
}
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
|
||||
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
|
||||
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
|
||||
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
|
||||
return kernel_source;
|
||||
@@ -178,7 +182,7 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type);
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
@@ -186,19 +190,23 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g1large_" + lib_name, "copy_g_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
|
||||
kernel_source += get_template_definition(
|
||||
"gg1large_" + lib_name, "copy_gg_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
|
||||
Reference in New Issue
Block a user