fix jit scan when output doesn't have primitive (#1190)

This commit is contained in:
Awni Hannun
2024-06-06 07:24:58 -07:00
committed by GitHub
parent 496315fe1d
commit 578842954c
4 changed files with 15 additions and 7 deletions

View File

@@ -170,11 +170,14 @@ MTL::ComputePipelineState* get_scan_kernel(
const std::string& kernel_name,
bool reverse,
bool inclusive,
const std::string& reduce_type,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::string op_name = "Cum" + reduce_type;
op_name[3] = toupper(op_name[3]);
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::scan()
<< fmt::format(
@@ -182,7 +185,7 @@ MTL::ComputePipelineState* get_scan_kernel(
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out),
op_name,
inclusive,
reverse);
lib = d.get_library(lib_name, kernel_source.str());