From b89d8ef1c07b990934b7acb2ad1caf51c8ffba00 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Wed, 9 Jul 2025 10:42:05 +0000
Subject: [PATCH] Strided scan

---
 mlx/backend/cuda/scan.cu | 208 +++++++++++++++++++++++++++++++++------
 1 file changed, 180 insertions(+), 28 deletions(-)

diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu
index 5d70d7ba6..198bf2dbb 100644
--- a/mlx/backend/cuda/scan.cu
+++ b/mlx/backend/cuda/scan.cu
@@ -39,37 +39,37 @@ struct ReduceInit {
 
 template 
 inline __device__ void
-load_vals(int index, const T* in, U (&vals)[N_READS], int size, U init) {
+load_values(int index, const T* in, U (&values)[N_READS], int size, U init) {
   int remaining = size - index * N_READS;
   if constexpr (reverse) {
     in += remaining - N_READS;
     if (remaining < N_READS) {
       for (int i = 0; i < N_READS; ++i) {
-        vals[N_READS - i - 1] =
+        values[N_READS - i - 1] =
             (N_READS - i - 1 < remaining) ? cast_to<U>(in[i]) : init;
       }
     } else {
       for (int i = 0; i < N_READS; ++i) {
-        vals[N_READS - i - 1] = cast_to<U>(in[i]);
+        values[N_READS - i - 1] = cast_to<U>(in[i]);
       }
     }
   } else {
     in += index * N_READS;
     if (remaining < N_READS) {
       for (int i = 0; i < N_READS; ++i) {
-        vals[i] = (i < remaining) ? cast_to<U>(in[i]) : init;
+        values[i] = (i < remaining) ? cast_to<U>(in[i]) : init;
       }
     } else {
       for (int i = 0; i < N_READS; ++i) {
-        vals[i] = cast_to<U>(in[i]);
+        values[i] = cast_to<U>(in[i]);
       }
     }
   }
 }
 
-template 
+template 
 inline __device__ void
-store_vals(int index, T* out, T (&vals)[N_READS], int size, int offset = 0) {
+store_values(int index, T* out, T (&values)[N_READS], int size) {
   int start = index * N_READS + offset;
   int remaining = size - start;
   if constexpr (reverse) {
@@ -77,12 +77,12 @@ store_vals(int index, T* out, T (&vals)[N_READS], int size, int offset = 0) {
     if (remaining < N_READS) {
       for (int i = 0; i < N_READS; ++i) {
         if (N_READS - i - 1 < remaining) {
-          out[i] = vals[N_READS - i - 1];
+          out[i] = values[N_READS - i - 1];
         }
       }
     } else {
       for (int i = 0; i < N_READS; ++i) {
-        out[i] = vals[N_READS - i - 1];
+        out[i] = values[N_READS - i - 1];
       }
     }
   } else {
@@ -90,12 +90,12 @@ store_vals(int index, T* out, T (&vals)[N_READS], int size, int offset = 0) {
     if (remaining < N_READS) {
       for (int i = 0; i < N_READS; ++i) {
         if (i < remaining) {
-          out[i] = vals[i];
+          out[i] = values[i];
         }
       }
     } else {
       for (int i = 0; i < N_READS; ++i) {
-        out[i] = vals[i];
+        out[i] = values[i];
       }
     }
   }
 }
@@ -125,24 +125,24 @@ __global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) {
   // Scan per block.
   for (int r = 0; r < cuda::ceil_div(axis_size, block.size() * N_READS); ++r) {
     int32_t index = r * block.size() + block.thread_rank();
-    U vals[N_READS];
-    load_vals(index, in, vals, axis_size, init);
+    U values[N_READS];
+    load_values(index, in, values, axis_size, init);
 
     // Compute an inclusive scan per thread.
-    for (int i = 1; i < N_READS; i++) {
-      vals[i] = op(vals[i], vals[i - 1]);
+    for (int i = 1; i < N_READS; ++i) {
+      values[i] = op(values[i], values[i - 1]);
     }
 
     // Compute exclusive scan of thread sums.
-    U prev_thread_sum = cg::exclusive_scan(warp, vals[N_READS - 1], op);
+    U prev_thread_sum = cg::exclusive_scan(warp, values[N_READS - 1], op);
     if (warp.thread_rank() == 0) {
       prev_thread_sum = init;
     }
 
     // Write warp's sum to shared memory.
-    if (warp.thread_rank() == warp.size() - 1) {
+    if (warp.thread_rank() == WARP_SIZE - 1) {
       warp_sums[warp.meta_group_rank()] =
-          op(prev_thread_sum, vals[N_READS - 1]);
+          op(prev_thread_sum, values[N_READS - 1]);
     }
     block.sync();
 
@@ -159,16 +159,16 @@ __global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) {
 
     // Compute the output.
     for (int i = 0; i < N_READS; ++i) {
-      vals[i] = op(vals[i], prefix);
-      vals[i] = op(vals[i], warp_sums[warp.meta_group_rank()]);
-      vals[i] = op(vals[i], prev_thread_sum);
+      values[i] = op(values[i], prefix);
+      values[i] = op(values[i], warp_sums[warp.meta_group_rank()]);
+      values[i] = op(values[i], prev_thread_sum);
     }
 
     // Write the values.
     if (inclusive) {
-      store_vals(index, out, vals, axis_size);
+      store_values(index, out, values, axis_size);
     } else {
-      store_vals(index, out, vals, axis_size, 1);
+      store_values(index, out, values, axis_size);
       if (reverse) {
         if (block.thread_rank() == 0 && index == 0) {
           out[axis_size - 1] = init;
@@ -183,14 +183,141 @@ __global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) {
 
     // Share the prefix.
     if ((warp.meta_group_rank() == warp.meta_group_size() - 1) &&
-        (warp.thread_rank() == warp.size() - 1)) {
-      warp_sums[0] = vals[N_READS - 1];
+        (warp.thread_rank() == WARP_SIZE - 1)) {
+      warp_sums[0] = values[N_READS - 1];
     }
     block.sync();
     prefix = warp_sums[0];
   }
 }
 
+template <
+    typename T,
+    typename U,
+    typename Op,
+    int N_READS,
+    int BM,
+    int BN,
+    bool inclusive,
+    bool reverse>
+__global__ void strided_scan(
+    const T* in,
+    U* out,
+    int32_t axis_size,
+    int64_t stride,
+    int64_t stride_blocks) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  constexpr int BN_pad = WARP_SIZE + 16 / sizeof(U);
+  constexpr int n_warps = BN / N_READS;
+  constexpr int n_scans = BN / n_warps;
+
+  __shared__ U read_buffer[BM * BN_pad];
+
+  Op op;
+  U init = ReduceInit::value();
+  U values[n_scans];
+  U prefix[n_scans];
+  for (int i = 0; i < n_scans; ++i) {
+    prefix[i] = init;
+  }
+
+  // Compute offsets.
+  int64_t offset = (grid.block_rank() / stride_blocks) * axis_size * stride;
+  int64_t global_index_x = (grid.block_rank() % stride_blocks) * BN;
+  uint read_offset_y = (block.thread_rank() * N_READS) / BN;
+  uint read_offset_x = (block.thread_rank() * N_READS) % BN;
+  uint scan_offset_y = warp.thread_rank();
+  uint scan_offset_x = warp.meta_group_rank() * n_scans;
+
+  uint stride_limit = stride - global_index_x;
+  in += offset + global_index_x + read_offset_x;
+  out += offset + global_index_x + read_offset_x;
+  U* read_into = read_buffer + read_offset_y * BN_pad + read_offset_x;
+  U* read_from = read_buffer + scan_offset_y * BN_pad + scan_offset_x;
+
+  for (uint j = 0; j < axis_size; j += BM) {
+    // Calculate the indices for the current thread.
+    uint index_y = j + read_offset_y;
+    uint check_index_y = index_y;
+    if (reverse) {
+      index_y = axis_size - 1 - index_y;
+    }
+
+    // Read in SM.
+    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
+      for (int i = 0; i < N_READS; ++i) {
+        read_into[i] = in[index_y * stride + i];
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
+          read_into[i] = in[index_y * stride + i];
+        } else {
+          read_into[i] = init;
+        }
+      }
+    }
+    block.sync();
+
+    // Read strided into registers.
+    for (int i = 0; i < n_scans; ++i) {
+      values[i] = read_from[i];
+    }
+    warp.sync();
+
+    // Perform the scan.
+    for (int i = 0; i < n_scans; ++i) {
+      values[i] = cg::inclusive_scan(warp, values[i], op);
+      values[i] = op(values[i], prefix[i]);
+      prefix[i] = warp.shfl(values[i], WARP_SIZE - 1);
+    }
+
+    // Write to SM.
+    for (int i = 0; i < n_scans; ++i) {
+      read_from[i] = values[i];
+    }
+    block.sync();
+
+    // Write to device memory.
+    if (!inclusive) {
+      if (check_index_y == 0) {
+        if ((read_offset_x + N_READS) < stride_limit) {
+          for (int i = 0; i < N_READS; ++i) {
+            out[index_y * stride + i] = init;
+          }
+        } else {
+          for (int i = 0; i < N_READS; ++i) {
+            if ((read_offset_x + i) < stride_limit) {
+              out[index_y * stride + i] = init;
+            }
+          }
+        }
+      }
+      if (reverse) {
+        index_y -= 1;
+        check_index_y += 1;
+      } else {
+        index_y += 1;
+        check_index_y += 1;
+      }
+    }
+    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
+      for (int i = 0; i < N_READS; ++i) {
+        out[index_y * stride + i] = read_into[i];
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
+          out[index_y * stride + i] = read_into[i];
        }
+      }
+    }
+  }
+}
+
 } // namespace cu
 
 template 
@@ -259,6 +386,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     out.copy_shared_buffer(in);
   }
 
+  constexpr int N_READS = 4;
+  int32_t axis_size = in.shape(axis_);
   bool contiguous = in.strides()[axis_] == 1;
 
   auto& encoder = cu::get_command_encoder(s);
@@ -274,7 +403,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
   dispatch_bool(inclusive_, [&](auto inclusive) {
     dispatch_bool(reverse_, [&](auto reverse) {
       if (contiguous) {
-        constexpr int N_READS = 4;
         auto kernel = cu::contiguous_scan<
             T,
             U,
@@ -282,7 +410,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
             N_READS,
             inclusive.value,
             reverse.value>;
-        int32_t axis_size = in.shape(axis_);
         int block_dim = cuda::ceil_div(axis_size, N_READS);
         block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE;
         block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE);
@@ -294,7 +421,32 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
             out.data<U>(),
             axis_size);
       } else {
-        throw std::runtime_error("Strided Scan NYI");
+        constexpr int BM = WARP_SIZE;
+        constexpr int BN = WARP_SIZE;
+        auto kernel = cu::strided_scan<
+            T,
+            U,
+            Op,
+            N_READS,
+            BM,
+            BN,
+            inclusive.value,
+            reverse.value>;
+        int64_t stride = in.strides()[axis_];
+        int64_t stride_blocks = cuda::ceil_div(stride, BN);
+        dim3 num_blocks = get_2d_grid_dims(
+            in.shape(), in.strides(), axis_size * stride);
+        num_blocks.x *= stride_blocks;
+        int block_dim = BN / N_READS * WARP_SIZE;
+        encoder.add_kernel_node(
+            kernel,
+            num_blocks,
+            block_dim,
+            in.data<T>(),
+            out.data<U>(),
+            axis_size,
+            stride,
+            stride_blocks);
      }
    });
  });
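Note on the launch geometry and semantics (this note and the sketch below are not part of the commit). The non-contiguous path launches one block per BN-wide slice of the stride dimension: stride_blocks = ceil_div(stride, BN), num_blocks.x is multiplied by it, and each block gets BN / N_READS * WARP_SIZE threads. With the BM = BN = WARP_SIZE and N_READS = 4 values chosen in eval_gpu, and assuming a warp size of 32, that is 8 warps of 32 threads per block; each warp carries the running prefix for n_scans = 4 columns while the block walks the scan axis in BM-row tiles staged through the padded shared-memory buffer.

As a quick reference for what the kernel is expected to compute, here is a hypothetical host-side sketch of a scan over the middle axis of an array viewed as [pre, axis_size, stride], i.e. between elements that are `stride` apart in memory. It is a minimal illustration under those assumptions, not MLX code; the names Add and scan_reference are made up for this note.

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical binary op; "init" plays the role ReduceInit plays in the kernel.
struct Add {
  float operator()(float a, float b) const { return a + b; }
  static constexpr float init = 0.0f;
};

// Scan along the middle axis of an array laid out as [pre, axis_size, stride].
template <typename Op>
void scan_reference(
    const std::vector<float>& in,
    std::vector<float>& out,
    int64_t pre,        // product of dimensions before the scan axis
    int32_t axis_size,  // length of the scan axis
    int64_t stride,     // product of dimensions after the scan axis
    bool inclusive,
    bool reverse) {
  Op op;
  for (int64_t p = 0; p < pre; ++p) {
    for (int64_t x = 0; x < stride; ++x) {
      int64_t base = p * axis_size * stride + x;
      float acc = Op::init;
      for (int32_t y = 0; y < axis_size; ++y) {
        int32_t idx = reverse ? axis_size - 1 - y : y;
        int64_t elem = base + int64_t(idx) * stride;
        if (inclusive) {
          acc = op(acc, in[elem]);
          out[elem] = acc;
        } else {
          out[elem] = acc;  // exclusive: prefix of everything before this element
          acc = op(acc, in[elem]);
        }
      }
    }
  }
}

int main() {
  // A [2, 3, 2] array scanned along axis 1 (axis_size = 3, stride = 2).
  std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  std::vector<float> out(in.size());
  scan_reference<Add>(
      in, out, /*pre=*/2, /*axis_size=*/3, /*stride=*/2,
      /*inclusive=*/true, /*reverse=*/false);
  for (float v : out) {
    std::printf("%g ", v);  // 1 2 4 6 9 12 7 8 16 18 27 30
  }
  std::printf("\n");
  return 0;
}

For an exclusive scan the kernel writes init into the first row of the output (the last row when reverse is set) and stores each scanned row shifted one position along the axis, which is what the else branch in the sketch mirrors.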