fix jit scan when output doesn't have primitive (#1190)

This commit is contained in:
Awni Hannun 2024-06-06 07:24:58 -07:00 committed by GitHub
parent 496315fe1d
commit 578842954c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 15 additions and 7 deletions

View File

@@ -170,11 +170,14 @@ MTL::ComputePipelineState* get_scan_kernel(
 const std::string& kernel_name,
 bool reverse,
 bool inclusive,
+const std::string& reduce_type,
 const array& in,
 const array& out) {
 std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
 auto lib = d.get_library(lib_name);
 if (lib == nullptr) {
+std::string op_name = "Cum" + reduce_type;
+op_name[3] = toupper(op_name[3]);
 std::ostringstream kernel_source;
 kernel_source << metal::utils() << metal::scan()
 << fmt::format(
@@ -182,7 +185,7 @@ MTL::ComputePipelineState* get_scan_kernel(
 lib_name,
 get_type_string(in.dtype()),
 get_type_string(out.dtype()),
-op_name(out),
+op_name,
 inclusive,
 reverse);
 lib = d.get_library(lib_name, kernel_source.str());

View File

@@ -49,6 +49,7 @@ MTL::ComputePipelineState* get_scan_kernel(
 const std::string& kernel_name,
 bool reverse,
 bool inclusive,
+const std::string& reduce_type,
 const array& in,
 const array& out);

View File

@@ -63,6 +63,7 @@ MTL::ComputePipelineState* get_scan_kernel(
 const std::string& kernel_name,
 bool,
 bool,
+const std::string&,
 const array&,
 const array&) {
 return d.get_kernel(kernel_name);

View File

@@ -38,22 +38,25 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
 kname << "reverse_";
 }
 kname << ((inclusive_) ? "inclusive_" : "exclusive_");
+std::string reduce_type;
 switch (reduce_type_) {
 case Scan::Sum:
-kname << "sum_";
+reduce_type = "sum";
 break;
 case Scan::Prod:
-kname << "prod_";
+reduce_type = "prod";
 break;
 case Scan::Max:
-kname << "max_";
+reduce_type = "max";
 break;
 case Scan::Min:
-kname << "min_";
+reduce_type = "min";
 break;
 }
-kname << type_to_name(in) << "_" << type_to_name(out);
-auto kernel = get_scan_kernel(d, kname.str(), reverse_, inclusive_, in, out);
+kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
+auto kernel = get_scan_kernel(
+d, kname.str(), reverse_, inclusive_, reduce_type, in, out);
 if (contiguous) {
 auto& compute_encoder = d.get_command_encoder(s.index);