fix jit scan when output doesn't have primitive (#1190)

This commit is contained in:
Awni Hannun 2024-06-06 07:24:58 -07:00 committed by GitHub
parent 496315fe1d
commit 578842954c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 15 additions and 7 deletions

View File

@ -170,11 +170,14 @@ MTL::ComputePipelineState* get_scan_kernel(
const std::string& kernel_name,
bool reverse,
bool inclusive,
const std::string& reduce_type,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::string op_name = "Cum" + reduce_type;
op_name[3] = toupper(op_name[3]);
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::scan()
<< fmt::format(
@ -182,7 +185,7 @@ MTL::ComputePipelineState* get_scan_kernel(
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out),
op_name,
inclusive,
reverse);
lib = d.get_library(lib_name, kernel_source.str());

View File

@ -49,6 +49,7 @@ MTL::ComputePipelineState* get_scan_kernel(
const std::string& kernel_name,
bool reverse,
bool inclusive,
const std::string& reduce_type,
const array& in,
const array& out);

View File

@ -63,6 +63,7 @@ MTL::ComputePipelineState* get_scan_kernel(
const std::string& kernel_name,
bool,
bool,
const std::string&,
const array&,
const array&) {
return d.get_kernel(kernel_name);

View File

@ -38,22 +38,25 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
kname << "reverse_";
}
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
std::string reduce_type;
switch (reduce_type_) {
case Scan::Sum:
kname << "sum_";
reduce_type = "sum";
break;
case Scan::Prod:
kname << "prod_";
reduce_type = "prod";
break;
case Scan::Max:
kname << "max_";
reduce_type = "max";
break;
case Scan::Min:
kname << "min_";
reduce_type = "min";
break;
}
kname << type_to_name(in) << "_" << type_to_name(out);
auto kernel = get_scan_kernel(d, kname.str(), reverse_, inclusive_, in, out);
kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
auto kernel = get_scan_kernel(
d, kname.str(), reverse_, inclusive_, reduce_type, in, out);
if (contiguous) {
auto& compute_encoder = d.get_command_encoder(s.index);