mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
export custom kernel (#2756)
This commit is contained in:
@@ -57,7 +57,7 @@ std::string build_kernel(
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
|
||||
const std::vector<CustomKernelShapeInfo>& shape_infos) {
|
||||
const std::vector<std::tuple<bool, bool, bool>>& shape_infos) {
|
||||
std::string kernel_source;
|
||||
kernel_source.reserve(header.size() + source.size() + 8192);
|
||||
kernel_source += default_header;
|
||||
@@ -81,17 +81,17 @@ std::string build_kernel(
|
||||
kernel_source += ",\n";
|
||||
// Add input shape, strides and ndim if present in the source
|
||||
if (arr.ndim() > 0) {
|
||||
if (shape_infos[i].shape) {
|
||||
if (std::get<0>(shape_infos[i])) {
|
||||
kernel_source += " const __grid_constant__ Shape ";
|
||||
kernel_source += name;
|
||||
kernel_source += "_shape,\n";
|
||||
}
|
||||
if (shape_infos[i].strides) {
|
||||
if (std::get<1>(shape_infos[i])) {
|
||||
kernel_source += " const __grid_constant__ Strides ";
|
||||
kernel_source += name;
|
||||
kernel_source += "_strides,\n";
|
||||
}
|
||||
if (shape_infos[i].ndim) {
|
||||
if (std::get<2>(shape_infos[i])) {
|
||||
kernel_source += " const __grid_constant__ int ";
|
||||
kernel_source += name;
|
||||
kernel_source += "_ndim,\n";
|
||||
@@ -154,12 +154,12 @@ CustomKernelFunction cuda_kernel(
|
||||
"[custom_kernel] Must specify at least one output.");
|
||||
}
|
||||
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
std::vector<std::tuple<bool, bool, bool>> shape_infos;
|
||||
for (auto& n : input_names) {
|
||||
CustomKernelShapeInfo shape_info;
|
||||
shape_info.shape = source.find(n + "_shape") != std::string::npos;
|
||||
shape_info.strides = source.find(n + "_strides") != std::string::npos;
|
||||
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
|
||||
std::tuple<bool, bool, bool> shape_info;
|
||||
std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
|
||||
std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
|
||||
std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
|
||||
shape_infos.push_back(shape_info);
|
||||
}
|
||||
|
||||
@@ -254,8 +254,8 @@ std::vector<array> precompiled_cuda_kernel(
|
||||
std::optional<float> init_value,
|
||||
bool ensure_row_contiguous,
|
||||
StreamOrDevice s) {
|
||||
std::vector<CustomKernelShapeInfo> shape_infos(
|
||||
inputs.size(), CustomKernelShapeInfo{false, false, false});
|
||||
std::vector<std::tuple<bool, bool, bool>> shape_infos(
|
||||
inputs.size(), {false, false, false});
|
||||
return array::make_arrays(
|
||||
output_shapes,
|
||||
output_dtypes,
|
||||
@@ -327,13 +327,13 @@ void CustomKernel::eval_gpu(
|
||||
const array& in = checked_inputs[i];
|
||||
auto& shape_info = shape_infos_[i];
|
||||
args.append(in);
|
||||
if (shape_info.shape) {
|
||||
if (std::get<0>(shape_info)) {
|
||||
args.append_ndim(in.shape());
|
||||
}
|
||||
if (shape_info.strides) {
|
||||
if (std::get<1>(shape_info)) {
|
||||
args.append_ndim(in.strides());
|
||||
}
|
||||
if (shape_info.ndim) {
|
||||
if (std::get<2>(shape_info)) {
|
||||
args.append<int32_t>(in.ndim());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user