diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index 35f2287d6..1ae46d0a3 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -7,8 +7,6 @@
 
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
-#include <cub/block/block_load.cuh>
-#include <cub/block/block_reduce.cuh>
 
 namespace mlx::core {
 
@@ -83,7 +81,8 @@ struct RowReduceArgs {
 };
 
 template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
-__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
+__global__ void
+row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -91,8 +90,8 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
   const U init = cu::ReduceInit<ReduceOp, T>::value();
   ReduceOp op;
 
-  T vals[M][N];
-  U accs[M];
+  AlignedVector<T, N> vals[M];
+  AlignedVector<U, M> accs;
   for (int i = 0; i < M; i++) {
     accs[i] = init;
   }
@@ -101,43 +100,31 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
       min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
   const size_t full_blocks = size / (block.size() * N);
   const size_t final_offset = full_blocks * (block.size() * N);
-  in += start_row * size;
+  in += start_row * size + block.thread_rank() * N;
   out += start_row;
 
-  if (size % N == 0) {
-    for (size_t r = 0; r < full_blocks; r++) {
-      for (int k = 0; k < M; k++) {
-        cub::LoadDirectBlockedVectorized(
-            block.thread_rank(),
-            in + k * size + r * (block.size() * N),
-            vals[k]);
-        for (int j = 0; j < N; j++) {
-          accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
-        }
-      }
-    }
-  } else {
-    for (size_t r = 0; r < full_blocks; r++) {
-      for (int k = 0; k < M; k++) {
-        cub::LoadDirectBlocked(
-            block.thread_rank(),
-            in + k * size + r * (block.size() * N),
-            vals[k]);
-        for (int j = 0; j < N; j++) {
-          accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
-        }
+  for (size_t r = 0; r < full_blocks; r++) {
+    for (int k = 0; k < M; k++) {
+      vals[k] = load_vector<N>(in + k * size, 0);
+    }
+    for (int k = 0; k < M; k++) {
+      for (int j = 0; j < N; j++) {
+        accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
       }
     }
+
+    in += block.size() * N;
   }
 
   if (final_offset < size) {
     for (int k = 0; k < M; k++) {
-      cub::LoadDirectBlocked(
-          block.thread_rank(),
-          in + k * size + final_offset,
-          vals[k],
-          size,
-          cast_to<T>(init));
+      for (int i = 0; i < N; i++) {
+        vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size)
+            ? in[k * size + i]
+            : cast_to<T>(init);
+      }
+    }
+    for (int k = 0; k < M; k++) {
       for (int j = 0; j < N; j++) {
         accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
       }
@@ -145,13 +132,11 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
   }
 
   __shared__ U shared_accumulators[32 * M];
-  block_reduce(block, warp, accs, shared_accumulators, op, init);
+  block_reduce(block, warp, accs.val, shared_accumulators, op, init);
 
   if (block.thread_rank() == 0) {
     if (grid.block_rank() * M + M <= n_rows) {
-      for (int i = 0; i < M; i++) {
-        out[i] = accs[i];
-      }
+      store_vector(out, 0, accs);
     } else {
       short offset = grid.block_rank() * M + M - n_rows;
       for (int i = offset; i < M; i++) {
@@ -161,17 +146,10 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
   }
 }
 
-template <
-    typename T,
-    typename U,
-    typename Op,
-    int NDIM,
-    int BLOCK_DIM,
-    int N_READS = 4>
+template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
 __global__ void row_reduce_looped(
-    T* in,
+    const T* in,
     U* out,
-    size_t out_size,
     const __grid_constant__ RowReduceArgs args) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
@@ -185,36 +163,60 @@ __global__ void row_reduce_looped(
   U init = ReduceInit<Op, T>::value();
   total[0] = init;
   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-  size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
-  size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
+  const size_t full_blocks = args.row_size / (block.size() * N_READS);
+  const size_t final_offset = full_blocks * (block.size() * N_READS);
 
   in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
+  in += block.thread_rank() * N_READS;
 
-  for (size_t n = 0; n < args.non_row_reductions; n++) {
-    for (size_t r = 0; r < full_blocks; r++) {
-      T vals[N_READS];
-      cub::LoadDirectBlockedVectorized(
-          block.thread_rank(),
-          in + loop.location() + r * BLOCK_DIM * N_READS,
-          vals);
-      for (int i = 0; i < N_READS; i++) {
-        total[0] = op(total[0], cast_to<U>(vals[i]));
-      }
+  // Unaligned reduce
+  if (final_offset < args.row_size) {
+    bool mask[N_READS];
+    for (int i = 0; i < N_READS; i++) {
+      mask[i] =
+          (final_offset + block.thread_rank() * N_READS + i) < args.row_size;
     }
-    if (final_offset < args.row_size) {
-      T vals[N_READS];
-      cub::LoadDirectBlocked(
-          block.thread_rank(),
-          in + loop.location() + final_offset,
-          vals,
-          args.row_size - final_offset,
-          cast_to<T>(init));
-      for (int i = 0; i < N_READS; i++) {
-        total[0] = op(total[0], cast_to<U>(vals[i]));
+
+    for (size_t n = 0; n < args.non_row_reductions; n++) {
+      const T* inlocal = in + loop.location();
+
+      for (size_t r = 0; r < full_blocks; r++) {
+        auto vals = load_vector<N_READS>(inlocal, 0);
+        for (int i = 0; i < N_READS; i++) {
+          total[0] = op(total[0], cast_to<U>(vals[i]));
+        }
+        inlocal += block.size() * N_READS;
       }
+
+      {
+        T vals[N_READS];
+        for (int i = 0; i < N_READS; i++) {
+          vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
+        }
+        for (int i = 0; i < N_READS; i++) {
+          total[0] = op(total[0], cast_to<U>(vals[i]));
+        }
+      }
+
+      loop.next(args.reduce_shape.data(), args.reduce_strides.data());
+    }
+  }
+
+  // Aligned case
+  else {
+    for (size_t n = 0; n < args.non_row_reductions; n++) {
+      const T* inlocal = in + loop.location();
+
+      for (size_t r = 0; r < full_blocks; r++) {
+        auto vals = load_vector<N_READS>(inlocal, 0);
+        for (int i = 0; i < N_READS; i++) {
+          total[0] = op(total[0], cast_to<U>(vals[i]));
+        }
+        inlocal += block.size() * N_READS;
+      }
+
+      loop.next(args.reduce_shape.data(), args.reduce_strides.data());
     }
-    // TODO: Maybe block.sync() here?
-    loop.next(args.reduce_shape.data(), args.reduce_strides.data());
   }
 
   __shared__ U shared_accumulators[32];
@@ -234,8 +236,6 @@ void row_reduce_simple(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
    const ReductionPlan& plan) {
-  constexpr int N_READS = 8;
-
   // Allocate data for the output using in's layout to avoid elem_to_loc in the
   // kernel.
   allocate_same_layout(out, in, axes);
@@ -250,14 +250,15 @@ void row_reduce_simple(
       using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       using U = typename cu::ReduceResult<OP, T>::type;
 
-      // Cub doesn't like const pointers for vectorized loads. (sigh)
-      T* indata = const_cast<T*>(in.data<T>());
+      constexpr int N_READS = 16 / sizeof(T);
 
       // Calculate the grid and block dims
       size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
       dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
-      int threads = std::min(1024UL, reductions);
-      threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+      int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
+      warps /= 4;
+      warps = std::max(std::min(warps, 32), 1);
+      int threads = warps * WARP_SIZE;
       dim3 block(threads, 1, 1);
 
       // Pick the kernel
@@ -267,6 +268,7 @@ void row_reduce_simple(
         kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
       }
 
+      T* indata = const_cast<T*>(in.data<T>());
       int size = plan.shape.back();
       encoder.add_kernel_node(
           kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
@@ -282,8 +284,6 @@ void row_reduce_looped(
     const std::vector<int>& axes,
     const ReductionPlan& plan,
     cu::RowReduceArgs args) {
-  constexpr int N_READS = 8;
-
   // Allocate data for the output using in's layout to access them as
   // contiguously as possible.
   allocate_same_layout(out, in, axes);
@@ -295,34 +295,27 @@ void row_reduce_looped(
       using OP = MLX_GET_TYPE(reduce_type_tag);
       using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
      using U = typename cu::ReduceResult<OP, T>::type;
-      // Cub doesn't like const pointers for vectorized loads. (sigh)
-      T* indata = const_cast<T*>(in.data<T>());
+
+      constexpr int N_READS = 16 / sizeof(T);
 
       // Calculate the grid and block dims
       args.sort_access_pattern(in, axes);
       dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
       size_t reductions = (args.row_size + N_READS - 1) / N_READS;
-      int threads = std::min(1024UL, reductions);
-      threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+      int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
+      warps /= 4;
+      warps = std::max(std::min(warps, 32), 1);
+      int threads = warps * WARP_SIZE;
       dim3 block(threads, 1, 1);
 
       // Pick the kernel
-      auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
+      auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>;
       dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
-        dispatch_block_dim(threads, [&](auto threads_constant) {
-          kernel = cu::row_reduce_looped<
-              T,
-              U,
-              OP,
-              reduce_ndim.value,
-              threads_constant.value,
-              N_READS>;
-          block.x = threads_constant.value;
-        });
+        kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>;
       });
 
       encoder.add_kernel_node(
-          kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
+          kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
     });
   });
 }
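
Reviewer note: a minimal host-side sketch of the new launch heuristic shared by row_reduce_simple and row_reduce_looped. The helper name pick_block_size and the standalone main below are illustrative only, not part of this patch. The block size is now derived from the number of N_READS-wide chunks in a row: count the warps needed to cover the row once, give each thread roughly four chunks, and clamp to between 1 and 32 warps.

#include <algorithm>
#include <cstddef>
#include <cstdio>

constexpr int WARP_SIZE = 32;

// Mirrors the sizing logic added in row_reduce_simple()/row_reduce_looped();
// in the patch N_READS is 16 / sizeof(T).
int pick_block_size(std::size_t row_size, int n_reads) {
  std::size_t reductions = (row_size + n_reads - 1) / n_reads; // chunks in one row
  int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;        // warps to cover the row once
  warps /= 4;                                                  // ~4 chunks per thread
  warps = std::max(std::min(warps, 32), 1);                    // clamp to [1, 32] warps
  return warps * WARP_SIZE;                                    // threads per block
}

int main() {
  // For float rows (N_READS = 4): a 4096-element row gets 256 threads,
  // while a 100-element row still gets one full warp.
  std::printf("%d %d\n", pick_block_size(4096, 4), pick_block_size(100, 4)); // 256 32
  return 0;
}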