diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu
index fb81079a5..1188faca6 100644
--- a/mlx/backend/cuda/reduce/col_reduce.cu
+++ b/mlx/backend/cuda/reduce/col_reduce.cu
@@ -89,9 +89,13 @@ template <
     int NDIM,
     int BM,
     int BN,
-    int N_READS = 4>
-__global__ void
-col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
+    int N_READS = 4,
+    int BLOCKS = 1>
+__global__ void col_reduce_looped(
+    T* in,
+    U* out,
+    const __grid_constant__ ColReduceArgs args,
+    int64_t out_size) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -102,6 +106,8 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
   size_t tile_idx = grid.block_rank();
   size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
   size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
+  size_t tile_out = tile_y / out_size;
+  tile_y = tile_y % out_size;
 
   // Compute the indices for the thread within the tile
   short thread_x = block.thread_rank() % threads_per_row;
@@ -118,12 +124,23 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
     totals[i] = ReduceInit<Op, T>::value();
   }
 
-  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-  loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
   size_t total = args.non_col_reductions * args.reduction_size;
+  size_t per_block, start, end;
+  if constexpr (BLOCKS > 1) {
+    per_block = (total + BLOCKS - 1) / BLOCKS;
+    start = tile_out * per_block + thread_y;
+    end = min((tile_out + 1) * per_block, total);
+  } else {
+    per_block = total;
+    start = thread_y;
+    end = total;
+  }
+
+  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
+  loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
   if (tile_x * BN + BN <= args.reduction_stride) {
     if (args.reduction_stride % N_READS == 0) {
-      for (size_t r = thread_y; r < total; r += BM) {
+      for (size_t r = start; r < end; r += BM) {
         T vals[N_READS];
         cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
         for (int i = 0; i < N_READS; i++) {
@@ -132,7 +149,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
         loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
       }
     } else {
-      for (size_t r = thread_y; r < total; r += BM) {
+      for (size_t r = start; r < end; r += BM) {
         T vals[N_READS];
         cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
         for (int i = 0; i < N_READS; i++) {
@@ -142,7 +159,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
       }
     }
   } else {
-    for (size_t r = thread_y; r < total; r += BM) {
+    for (size_t r = start; r < end; r += BM) {
       T vals[N_READS];
       cub::LoadDirectBlocked(
           thread_x,
@@ -173,6 +190,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
 
   // Write result.
   if (warp.thread_rank() == 0) {
+    if (BLOCKS > 1) {
+      out += tile_out * out_size * args.reduction_stride;
+    }
     cub::StoreDirectBlocked(
         warp.meta_group_rank(),
         out + tile_y * args.reduction_stride + tile_x * BN,
@@ -227,11 +247,12 @@ __global__ void col_reduce_small(
 inline auto output_grid_for_col_reduce(
     const array& out,
     const cu::ColReduceArgs& args,
-    int bn) {
+    int bn,
+    int outer = 1) {
   int gx, gy = 1;
   size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
   size_t n_outer_blocks = out.size() / args.reduction_stride;
-  size_t n_blocks = n_outer_blocks * n_inner_blocks;
+  size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
   while (n_blocks / gy > INT32_MAX) {
     gy *= 2;
   }
@@ -277,7 +298,8 @@ void col_reduce_looped(
             0,
             indata,
             gpu_ptr(out),
-            static_cast<cu::ColReduceArgs>(args));
+            static_cast<cu::ColReduceArgs>(args),
+            out.size() / args.reduction_stride);
       });
     });
   });
@@ -320,6 +342,117 @@ void col_reduce_small(
   });
 }
 
+void col_reduce_two_pass(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    const cu::ColReduceArgs& args) {
+  // Allocate data for the output using in's layout to access them as
+  // contiguously as possible.
+  allocate_same_layout(out, in, axes, encoder);
+
+  // Allocate an intermediate array to hold the 1st pass result
+  constexpr int outer = 32;
+
+  Shape intermediate_shape;
+  intermediate_shape.push_back(outer);
+  intermediate_shape.insert(
+      intermediate_shape.end(), out.shape().begin(), out.shape().end());
+
+  Strides intermediate_strides;
+  intermediate_strides.push_back(out.size());
+  intermediate_strides.insert(
+      intermediate_strides.end(), out.strides().begin(), out.strides().end());
+
+  array intermediate(intermediate_shape, out.dtype(), nullptr, {});
+  auto [data_size, rc, cc] =
+      check_contiguity(intermediate_shape, intermediate_strides);
+  auto fl = out.flags();
+  fl.row_contiguous = rc;
+  fl.col_contiguous = cc;
+  fl.contiguous = true;
+  intermediate.set_data(
+      cu::malloc_async(intermediate.nbytes(), encoder),
+      data_size,
+      intermediate_strides,
+      fl,
+      allocator::free);
+
+  encoder.add_temporary(intermediate);
+  encoder.set_input_array(in);
+  encoder.set_output_array(intermediate);
+  dispatch_all_types(in.dtype(), [&](auto type_tag) {
+    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+      dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
+        using OP = MLX_GET_TYPE(reduce_type_tag);
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        using U = typename cu::ReduceResult<OP, T>::type;
+        // Cub doesn't like const pointers for vectorized loads. (sigh)
+        T* indata = const_cast<T*>(gpu_ptr(in));
+
+        constexpr int N_READS = 4;
+        constexpr int BM = 32;
+        constexpr int BN = 32;
+        dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
+        int blocks = BM * BN / N_READS;
+        auto kernel = cu::
+            col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
+        encoder.add_kernel_node(
+            kernel,
+            grid,
+            blocks,
+            0,
+            indata,
+            gpu_ptr(intermediate),
+            static_cast<cu::ColReduceArgs>(args),
+            out.size() / args.reduction_stride);
+      });
+    });
+  });
+
+  // Prepare the reduction arguments for the 2nd pass
+  cu::ColReduceArgs second_args = args;
+  second_args.reduction_size = outer;
+  second_args.reduction_stride = out.size();
+  second_args.ndim = 0;
+  second_args.reduce_shape[0] = outer;
+  second_args.reduce_strides[0] = out.size();
+  second_args.reduce_ndim = 1;
+  second_args.non_col_reductions = 1;
+
+  encoder.set_input_array(intermediate);
+  encoder.set_output_array(out);
+  dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
+    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+      dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
+        using OP = MLX_GET_TYPE(reduce_type_tag);
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        using U = typename cu::ReduceResult<OP, T>::type;
+
+        constexpr int N_READS = 4;
+        constexpr int BM = 32;
+        constexpr int BN = 32;
+        dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
+        int blocks = BM * BN / N_READS;
+        auto kernel =
+            cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
+        encoder.add_kernel_node(
+            kernel,
+            grid,
+            blocks,
+            0,
+            gpu_ptr(intermediate),
+            gpu_ptr(out),
+            second_args,
+            second_args.reduction_stride);
+      });
+    });
+  });
+}
+
 void col_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -334,6 +467,18 @@ void col_reduce(
   // It is a general strided reduce. Each threadblock computes the output for
   // a subrow of the fast moving axis. For instance 32 elements.
   //
+  // - col_reduce_small
+  //
+  // It is a column reduce for small columns. Each thread loops over the whole
+  // column without communicating with any other thread.
+  //
+  // - col_reduce_two_pass
+  //
+  // It is a reduce for long columns. To increase parallelism, we split the
+  // reduction in two passes. First we do a column reduce where many
+  // threadblocks operate on different parts of the reduced axis. Then we
+  // perform a final column reduce.
+  //
   // Notes: As in row reduce we opt to read as much in order as possible and
   // leave transpositions as they are (contrary to our Metal backend).
   //
@@ -349,6 +494,14 @@ void col_reduce(
     return;
   }
 
+  // Long column with smallish row
+  size_t total_sums = args.non_col_reductions * args.reduction_size;
+  size_t approx_threads = out.size();
+  if (total_sums / approx_threads > 32) {
+    col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
+    return;
+  }
+
   // Fallback col reduce
   col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
 }
diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py
index d6ddf353b..24d5a688d 100644
--- a/python/tests/test_reduce.py
+++ b/python/tests/test_reduce.py
@@ -210,6 +210,14 @@ class TestReduce(mlx_tests.MLXTestCase):
                 ref = getattr(np, op)(np_arr, axis=axis)
                 self.assertTrue(np.array_equal(out, ref, equal_nan=True))
 
+    def test_long_column(self):
+        a = (np.random.randn(8192, 64) * 32).astype(np.int32)
+        b = mx.array(a)
+
+        c1 = a.sum(0)
+        c2 = b.sum(0)
+        self.assertTrue(np.all(c1 == c2))
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner(failfast=True)
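
For readers skimming the patch, the core idea of col_reduce_two_pass is easier to see without the dispatch machinery: split the reduced rows into `outer` chunks, let independent blocks write partial results into an [outer, out.shape...] intermediate, then run a second, small column reduce over the chunk axis. The standalone CUDA sketch below illustrates that strategy only; it is not the MLX kernel, and the names (partial_col_sum, final_col_sum, OUTER, BLOCK) and the plain float/sum specialization are illustrative assumptions of this example, not part of the patch.

// Standalone sketch (not the MLX kernel): two-pass column sum for a
// row-major [rows x cols] float matrix. Pass 1 assigns each grid.y slice
// a chunk of rows and writes partial sums to an [OUTER x cols] buffer;
// pass 2 folds the OUTER partials into the final column sums.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int OUTER = 32;  // number of row chunks, analogous to `outer` above
constexpr int BLOCK = 256; // threads per block, one column per thread

__global__ void partial_col_sum(
    const float* in, float* partial, int rows, int cols) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= cols) return;
  // Each y-slice of the grid reduces a contiguous chunk of rows.
  int per_chunk = (rows + OUTER - 1) / OUTER;
  int start = blockIdx.y * per_chunk;
  int end = min(start + per_chunk, rows);
  float acc = 0.0f;
  for (int r = start; r < end; r++) {
    acc += in[static_cast<size_t>(r) * cols + col];
  }
  partial[static_cast<size_t>(blockIdx.y) * cols + col] = acc;
}

__global__ void final_col_sum(const float* partial, float* out, int cols) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= cols) return;
  float acc = 0.0f;
  for (int i = 0; i < OUTER; i++) {
    acc += partial[static_cast<size_t>(i) * cols + col];
  }
  out[col] = acc;
}

int main() {
  const int rows = 8192, cols = 64; // a long column, short row case
  std::vector<float> h_in(static_cast<size_t>(rows) * cols, 1.0f);
  float *d_in, *d_partial, *d_out;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_partial, OUTER * cols * sizeof(float));
  cudaMalloc(&d_out, cols * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float),
             cudaMemcpyHostToDevice);

  dim3 grid1((cols + BLOCK - 1) / BLOCK, OUTER);
  partial_col_sum<<<grid1, BLOCK>>>(d_in, d_partial, rows, cols);
  final_col_sum<<<(cols + BLOCK - 1) / BLOCK, BLOCK>>>(d_partial, d_out, cols);

  std::vector<float> h_out(cols);
  cudaMemcpy(h_out.data(), d_out, cols * sizeof(float),
             cudaMemcpyDeviceToHost);
  std::printf("col 0 sum = %.0f (expected %d)\n", h_out[0], rows);
  cudaFree(d_in);
  cudaFree(d_partial);
  cudaFree(d_out);
  return 0;
}

The heuristic in the patch enables this path only when total_sums / out.size() > 32, i.e. when the reduced column is long relative to the number of outputs, so the cost of the extra intermediate traffic is paid back by the added parallelism of the first pass.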