mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Dynamic slicing (#1741)
* dynamic slice and slice update * python bindings + tests + fix set item * fix compile issue * comment * fix jit
This commit is contained in:
@@ -218,6 +218,38 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_dynamic_copy_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::copy();
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
kernel_source += get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"ggn2_" + lib_name, "copy_gg_dynamic", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg1large_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg2large_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg3large_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"ggn4large_" + lib_name, "copy_gg_dynamic", in_type, out_type, 4);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_softmax_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
||||
Reference in New Issue
Block a user