From e769fcca6041e7669f7afb9f619824563f405b11 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Tue, 8 Jul 2025 23:21:07 +0000
Subject: [PATCH] Contiguous scan

---
 mlx/backend/cuda/CMakeLists.txt          |   6 +
 mlx/backend/cuda/primitives.cu           |   1 -
 mlx/backend/cuda/reduce/reduce_utils.cuh |   1 +
 mlx/backend/cuda/scan.cu                 | 312 +++++++++++++++++++++++
 4 files changed, 319 insertions(+), 1 deletion(-)
 create mode 100644 mlx/backend/cuda/scan.cu

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 8130d396f..87f4cb4ae 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -35,6 +35,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
@@ -67,6 +68,11 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
 target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
 
+# Enable calling host constexpr functions from device. This is needed because
+# the constexpr version of isnan is host only.
+target_compile_options(
+  mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
+
 # CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
 # Explicitly pass this flag to suppress the warning, it is safe to set it to
 # true but the warning wouldn't be suppressed.
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index a8496b958..3a3f8ff54 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -82,7 +82,6 @@ NO_GPU(Load)
 NO_GPU_MULTI(LUF)
 NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
-NO_GPU(Scan)
 NO_GPU(SegmentedMM)
 NO_GPU_MULTI(SVD)
 NO_GPU(Inverse)
diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh
index ccd7ae48d..d993bacbb 100644
--- a/mlx/backend/cuda/reduce/reduce_utils.cuh
+++ b/mlx/backend/cuda/reduce/reduce_utils.cuh
@@ -4,6 +4,7 @@
 
 #include
 
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device/utils.cuh"
 
 #include
diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu
new file mode 100644
index 000000000..5d70d7ba6
--- /dev/null
+++ b/mlx/backend/cuda/scan.cu
@@ -0,0 +1,312 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/binary_ops.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/scan.h>
+#include <nvtx3/nvtx3.hpp>
+
+#include <cassert>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename Op, typename T>
+struct ScanResult {
+  using type = T;
+};
+
+template <>
+struct ScanResult<Sum, bool> {
+  using type = int32_t;
+};
+
+template <typename T>
+struct ReduceInit<LogAddExp, T> {
+  static constexpr __host__ __device__ T value() {
+    return Limits<T>::min();
+  }
+};
+
+template <bool reverse, typename T, typename U, int N_READS>
+inline __device__ void
+load_vals(int index, const T* in, U (&vals)[N_READS], int size, U init) {
+  int remaining = size - index * N_READS;
+  if constexpr (reverse) {
+    in += remaining - N_READS;
+    if (remaining < N_READS) {
+      for (int i = 0; i < N_READS; ++i) {
+        vals[N_READS - i - 1] =
+            (N_READS - i - 1 < remaining) ? cast_to<U>(in[i]) : init;
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        vals[N_READS - i - 1] = cast_to<U>(in[i]);
+      }
+    }
+  } else {
+    in += index * N_READS;
+    if (remaining < N_READS) {
+      for (int i = 0; i < N_READS; ++i) {
+        vals[i] = (i < remaining) ? cast_to<U>(in[i]) : init;
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        vals[i] = cast_to<U>(in[i]);
+      }
+    }
+  }
+}
+
+template <bool reverse, typename T, int N_READS>
+inline __device__ void
+store_vals(int index, T* out, T (&vals)[N_READS], int size, int offset = 0) {
+  int start = index * N_READS + offset;
+  int remaining = size - start;
+  if constexpr (reverse) {
+    out += remaining - N_READS;
+    if (remaining < N_READS) {
+      for (int i = 0; i < N_READS; ++i) {
+        if (N_READS - i - 1 < remaining) {
+          out[i] = vals[N_READS - i - 1];
+        }
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        out[i] = vals[N_READS - i - 1];
+      }
+    }
+  } else {
+    out += start;
+    if (remaining < N_READS) {
+      for (int i = 0; i < N_READS; ++i) {
+        if (i < remaining) {
+          out[i] = vals[i];
+        }
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        out[i] = vals[i];
+      }
+    }
+  }
+}
+
+template <
+    typename T,
+    typename U,
+    typename Op,
+    int N_READS,
+    bool inclusive,
+    bool reverse>
+__global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  in += grid.block_rank() * axis_size;
+  out += grid.block_rank() * axis_size;
+
+  __shared__ U warp_sums[WARP_SIZE];
+
+  Op op;
+  U init = ReduceInit<Op, T>::value();
+  U prefix = init;
+
+  // Scan per block.
+  for (int r = 0; r < cuda::ceil_div(axis_size, block.size() * N_READS); ++r) {
+    int32_t index = r * block.size() + block.thread_rank();
+    U vals[N_READS];
+    load_vals<reverse>(index, in, vals, axis_size, init);
+
+    // Compute an inclusive scan per thread.
+    for (int i = 1; i < N_READS; i++) {
+      vals[i] = op(vals[i], vals[i - 1]);
+    }
+
+    // Compute exclusive scan of thread sums.
+    U prev_thread_sum = cg::exclusive_scan(warp, vals[N_READS - 1], op);
+    if (warp.thread_rank() == 0) {
+      prev_thread_sum = init;
+    }
+
+    // Write the warp's sum to shared memory.
+    if (warp.thread_rank() == warp.size() - 1) {
+      warp_sums[warp.meta_group_rank()] =
+          op(prev_thread_sum, vals[N_READS - 1]);
+    }
+    block.sync();
+
+    // Compute exclusive scan of warp sums.
+    if (warp.meta_group_rank() == 0) {
+      U prev_warp_sum =
+          cg::exclusive_scan(warp, warp_sums[warp.thread_rank()], op);
+      if (warp.thread_rank() == 0) {
+        prev_warp_sum = init;
+      }
+      warp_sums[warp.thread_rank()] = prev_warp_sum;
+    }
+    block.sync();
+
+    // Compute the output.
+    for (int i = 0; i < N_READS; ++i) {
+      vals[i] = op(vals[i], prefix);
+      vals[i] = op(vals[i], warp_sums[warp.meta_group_rank()]);
+      vals[i] = op(vals[i], prev_thread_sum);
+    }
+
+    // Write the values.
+    if (inclusive) {
+      store_vals<reverse>(index, out, vals, axis_size);
+    } else {
+      store_vals<reverse>(index, out, vals, axis_size, 1);
+      if (reverse) {
+        if (block.thread_rank() == 0 && index == 0) {
+          out[axis_size - 1] = init;
+        }
+      } else {
+        if (block.thread_rank() == 0 && index == 0) {
+          out[0] = init;
+        }
+      }
+    }
+    block.sync();
+
+    // Share the prefix with the next iteration.
+    if ((warp.meta_group_rank() == warp.meta_group_size() - 1) &&
+        (warp.thread_rank() == warp.size() - 1)) {
+      warp_sums[0] = vals[N_READS - 1];
+    }
+    block.sync();
+    prefix = warp_sums[0];
+  }
+}
+
+} // namespace cu
+
+template <typename F>
+void dispatch_scan_ops(Scan::ReduceType scan_op, F&& f) {
+  if (scan_op == Scan::ReduceType::Max) {
+    f(type_identity<cu::Max>{});
+  } else if (scan_op == Scan::ReduceType::Min) {
+    f(type_identity<cu::Min>{});
+  } else if (scan_op == Scan::ReduceType::Sum) {
+    f(type_identity<cu::Sum>{});
+  } else if (scan_op == Scan::ReduceType::Prod) {
+    f(type_identity<cu::Prod>{});
+  } else if (scan_op == Scan::ReduceType::LogAddExp) {
+    f(type_identity<cu::LogAddExp>{});
+  } else {
+    throw std::invalid_argument("Unknown reduce type.");
+  }
+}
+
+template <typename Op>
+const char* op_to_string() {
+  if (cuda::std::is_same_v<Op, cu::Max>) {
+    return "Max";
+  } else if (cuda::std::is_same_v<Op, cu::Min>) {
+    return "Min";
+  } else if (cuda::std::is_same_v<Op, cu::Sum>) {
+    return "Sum";
+  } else if (cuda::std::is_same_v<Op, cu::Prod>) {
+    return "Prod";
+  } else if (cuda::std::is_same_v<Op, cu::LogAddExp>) {
+    return "LogAddExp";
+  } else {
+    throw std::invalid_argument("Unknown op.");
+  }
+}
+
+template <typename Op, typename T>
+constexpr bool supports_scan_op() {
+  if constexpr (cuda::std::is_same_v<Op, cu::LogAddExp>) {
+    return is_inexact_v<T>;
+  } else {
+    return true;
+  }
+}
+
+void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Scan::eval_gpu");
+  assert(inputs.size() == 1);
+  auto in = inputs[0];
+  auto& s = stream();
+
+  if (in.flags().contiguous && in.strides()[axis_] != 0) {
+    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+      out.copy_shared_buffer(in);
+    } else {
+      out.set_data(
+          allocator::malloc(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
+  } else {
+    array arr_copy(in.shape(), in.dtype(), nullptr, {});
+    copy_gpu(in, arr_copy, CopyType::General, s);
+    in = std::move(arr_copy);
+    out.copy_shared_buffer(in);
+  }
+
+  bool contiguous = in.strides()[axis_] == 1;
+
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+
+  dispatch_all_types(in.dtype(), [&](auto type_tag) {
+    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    dispatch_scan_ops(reduce_type_, [&](auto scan_op_tag) {
+      using Op = MLX_GET_TYPE(scan_op_tag);
+      if constexpr (supports_scan_op<Op, T>()) {
+        using U = typename cu::ScanResult<Op, T>::type;
+        dispatch_bool(inclusive_, [&](auto inclusive) {
+          dispatch_bool(reverse_, [&](auto reverse) {
+            if (contiguous) {
+              constexpr int N_READS = 4;
+              auto kernel = cu::contiguous_scan<
+                  T,
+                  U,
+                  Op,
+                  N_READS,
+                  inclusive.value,
+                  reverse.value>;
+              int32_t axis_size = in.shape(axis_);
+              int block_dim = cuda::ceil_div(axis_size, N_READS);
+              block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE;
+              block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE);
+              encoder.add_kernel_node(
+                  kernel,
+                  in.data_size() / axis_size,
+                  block_dim,
+                  in.data<T>(),
+                  out.data<U>(),
+                  axis_size);
+            } else {
+              throw std::runtime_error("Strided Scan NYI");
+            }
+          });
+        });
+      } else {
+        throw std::runtime_error(fmt::format(
+            "Can not do scan op {} on inputs of {} with result of {}.",
+            op_to_string<Op>(),
+            dtype_to_string(in.dtype()),
+            dtype_to_string(out.dtype())));
+      }
+    });
+  });
+}
+
+} // namespace mlx::core
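
Note for reviewers (not part of the patch): the kernel composes three scan levels per block iteration: an inclusive scan inside each thread's N_READS values, a cooperative-groups exclusive scan across the warp, and an exclusive scan of per-warp sums, with a running prefix carried between iterations. The standalone sketch below only demonstrates the cooperative-groups building block used here, cg::exclusive_scan over a 32-thread tile, with lane 0 explicitly seeded with the identity just as the kernel seeds it with init. The file name, kernel name, and launch configuration are illustrative, not taken from the patch; it assumes CUDA 11+ and compiles with plain nvcc.

// warp_scan_demo.cu (hypothetical example, compile with: nvcc warp_scan_demo.cu)
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h> // cg::plus
#include <cooperative_groups/scan.h>   // cg::exclusive_scan

#include <cstdio>
#include <cuda_runtime.h>

namespace cg = cooperative_groups;

// One 32-thread block; each lane receives the sum of all lower-ranked lanes.
__global__ void warp_exclusive_scan_demo(const int* in, int* out) {
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);
  int v = in[warp.thread_rank()];
  int prev = cg::exclusive_scan(warp, v, cg::plus<int>());
  if (warp.thread_rank() == 0) {
    prev = 0; // seed lane 0 with the identity, mirroring the kernel's use of init
  }
  out[warp.thread_rank()] = prev;
}

int main() {
  int h_in[32], h_out[32];
  for (int i = 0; i < 32; ++i) h_in[i] = 1;
  int *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warp_exclusive_scan_demo<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (int i = 0; i < 32; ++i) printf("%d ", h_out[i]); // expect 0 1 2 ... 31
  printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}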