From 578842954cad109bff721f9f68a510b8f121c79a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 6 Jun 2024 07:24:58 -0700 Subject: [PATCH] fix jit scan when output doesn't have primitive (#1190) --- mlx/backend/metal/jit_kernels.cpp | 5 ++++- mlx/backend/metal/kernels.h | 1 + mlx/backend/metal/nojit_kernels.cpp | 1 + mlx/backend/metal/scan.cpp | 15 +++++++++------ 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 4616f0cd6..5c84105c9 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -170,11 +170,14 @@ MTL::ComputePipelineState* get_scan_kernel( const std::string& kernel_name, bool reverse, bool inclusive, + const std::string& reduce_type, const array& in, const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name); if (lib == nullptr) { + std::string op_name = "Cum" + reduce_type; + op_name[3] = toupper(op_name[3]); std::ostringstream kernel_source; kernel_source << metal::utils() << metal::scan() << fmt::format( @@ -182,7 +185,7 @@ MTL::ComputePipelineState* get_scan_kernel( lib_name, get_type_string(in.dtype()), get_type_string(out.dtype()), - op_name(out), + op_name, inclusive, reverse); lib = d.get_library(lib_name, kernel_source.str()); diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index ed1181e4a..94bf609f3 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -49,6 +49,7 @@ MTL::ComputePipelineState* get_scan_kernel( const std::string& kernel_name, bool reverse, bool inclusive, + const std::string& reduce_type, const array& in, const array& out); diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index f91db5a53..4db547e36 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -63,6 +63,7 @@ MTL::ComputePipelineState* get_scan_kernel( const std::string& kernel_name, bool, bool, + const std::string&, const array&, const array&) { return d.get_kernel(kernel_name); diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index 63ca132ff..00596c687 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -38,22 +38,25 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { kname << "reverse_"; } kname << ((inclusive_) ? "inclusive_" : "exclusive_"); + + std::string reduce_type; switch (reduce_type_) { case Scan::Sum: - kname << "sum_"; + reduce_type = "sum"; break; case Scan::Prod: - kname << "prod_"; + reduce_type = "prod"; break; case Scan::Max: - kname << "max_"; + reduce_type = "max"; break; case Scan::Min: - kname << "min_"; + reduce_type = "min"; break; } - kname << type_to_name(in) << "_" << type_to_name(out); - auto kernel = get_scan_kernel(d, kname.str(), reverse_, inclusive_, in, out); + kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out); + auto kernel = get_scan_kernel( + d, kname.str(), reverse_, inclusive_, reduce_type, in, out); if (contiguous) { auto& compute_encoder = d.get_command_encoder(s.index);