From 6ccfa603cdb21a1d27f1e06de02985c62d25def5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 15 Sep 2025 11:01:57 -0700 Subject: [PATCH] fix metal scan (#2591) --- mlx/backend/metal/kernels/scan.h | 2 ++ mlx/backend/metal/scan.cpp | 27 ++++++++++++++++----------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/mlx/backend/metal/kernels/scan.h b/mlx/backend/metal/kernels/scan.h index cb5147558..16682613b 100644 --- a/mlx/backend/metal/kernels/scan.h +++ b/mlx/backend/metal/kernels/scan.h @@ -306,6 +306,7 @@ template < U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]); // Write simdgroup_sums to SM + threadgroup_barrier(mem_flags::mem_threadgroup); if (simd_lane_id == simd_size - 1) { simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]); } @@ -440,6 +441,7 @@ template < } // Read in SM + threadgroup_barrier(mem_flags::mem_threadgroup); if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { for (int i = 0; i < N_READS; i++) { read_into[i] = in[index_y * stride + i]; diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index fd1899108..ca208a36c 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -36,14 +36,6 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { bool contiguous = in.strides()[axis_] == 1; - std::ostringstream kname; - kname << (contiguous ? "contig_" : "strided_"); - kname << "scan_"; - if (reverse_) { - kname << "reverse_"; - } - kname << ((inclusive_) ? "inclusive_" : "exclusive_"); - std::string reduce_type; switch (reduce_type_) { case Scan::Sum: @@ -62,9 +54,22 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { reduce_type = "logaddexp"; break; } - kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out); - auto kernel = get_scan_kernel( - d, kname.str(), reverse_, inclusive_, reduce_type, in, out); + + std::string kname; + concatenate( + kname, + contiguous ? "contig_" : "strided_", + "scan_", + reverse_ ? "reverse_" : "", + (inclusive_) ? "inclusive_" : "exclusive_", + reduce_type, + "_", + type_to_name(in), + "_", + type_to_name(out)); + + auto kernel = + get_scan_kernel(d, kname, reverse_, inclusive_, reduce_type, in, out); if (contiguous) { auto& compute_encoder = d.get_command_encoder(s.index);