mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add a 2-pass col reduce for CUDA (#2863)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
This commit is contained in:
committed by
GitHub
parent
1fa8dc5797
commit
997cfc7699
@@ -89,9 +89,13 @@ template <
|
|||||||
int NDIM,
|
int NDIM,
|
||||||
int BM,
|
int BM,
|
||||||
int BN,
|
int BN,
|
||||||
int N_READS = 4>
|
int N_READS = 4,
|
||||||
__global__ void
|
int BLOCKS = 1>
|
||||||
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
__global__ void col_reduce_looped(
|
||||||
|
T* in,
|
||||||
|
U* out,
|
||||||
|
const __grid_constant__ ColReduceArgs args,
|
||||||
|
int64_t out_size) {
|
||||||
auto grid = cg::this_grid();
|
auto grid = cg::this_grid();
|
||||||
auto block = cg::this_thread_block();
|
auto block = cg::this_thread_block();
|
||||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
@@ -102,6 +106,8 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
size_t tile_idx = grid.block_rank();
|
size_t tile_idx = grid.block_rank();
|
||||||
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
|
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
|
||||||
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
|
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
|
||||||
|
size_t tile_out = tile_y / out_size;
|
||||||
|
tile_y = tile_y % out_size;
|
||||||
|
|
||||||
// Compute the indices for the thread within the tile
|
// Compute the indices for the thread within the tile
|
||||||
short thread_x = block.thread_rank() % threads_per_row;
|
short thread_x = block.thread_rank() % threads_per_row;
|
||||||
@@ -118,12 +124,23 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
totals[i] = ReduceInit<Op, T>::value();
|
totals[i] = ReduceInit<Op, T>::value();
|
||||||
}
|
}
|
||||||
|
|
||||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
|
||||||
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
|
|
||||||
size_t total = args.non_col_reductions * args.reduction_size;
|
size_t total = args.non_col_reductions * args.reduction_size;
|
||||||
|
size_t per_block, start, end;
|
||||||
|
if constexpr (BLOCKS > 1) {
|
||||||
|
per_block = (total + BLOCKS - 1) / BLOCKS;
|
||||||
|
start = tile_out * per_block + thread_y;
|
||||||
|
end = min((tile_out + 1) * per_block, total);
|
||||||
|
} else {
|
||||||
|
per_block = total;
|
||||||
|
start = thread_y;
|
||||||
|
end = total;
|
||||||
|
}
|
||||||
|
|
||||||
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
|
loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
if (tile_x * BN + BN <= args.reduction_stride) {
|
if (tile_x * BN + BN <= args.reduction_stride) {
|
||||||
if (args.reduction_stride % N_READS == 0) {
|
if (args.reduction_stride % N_READS == 0) {
|
||||||
for (size_t r = thread_y; r < total; r += BM) {
|
for (size_t r = start; r < end; r += BM) {
|
||||||
T vals[N_READS];
|
T vals[N_READS];
|
||||||
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
|
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
@@ -132,7 +149,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t r = thread_y; r < total; r += BM) {
|
for (size_t r = start; r < end; r += BM) {
|
||||||
T vals[N_READS];
|
T vals[N_READS];
|
||||||
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
|
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
@@ -142,7 +159,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t r = thread_y; r < total; r += BM) {
|
for (size_t r = start; r < end; r += BM) {
|
||||||
T vals[N_READS];
|
T vals[N_READS];
|
||||||
cub::LoadDirectBlocked(
|
cub::LoadDirectBlocked(
|
||||||
thread_x,
|
thread_x,
|
||||||
@@ -173,6 +190,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
|||||||
|
|
||||||
// Write result.
|
// Write result.
|
||||||
if (warp.thread_rank() == 0) {
|
if (warp.thread_rank() == 0) {
|
||||||
|
if (BLOCKS > 1) {
|
||||||
|
out += tile_out * out_size * args.reduction_stride;
|
||||||
|
}
|
||||||
cub::StoreDirectBlocked(
|
cub::StoreDirectBlocked(
|
||||||
warp.meta_group_rank(),
|
warp.meta_group_rank(),
|
||||||
out + tile_y * args.reduction_stride + tile_x * BN,
|
out + tile_y * args.reduction_stride + tile_x * BN,
|
||||||
@@ -227,11 +247,12 @@ __global__ void col_reduce_small(
|
|||||||
inline auto output_grid_for_col_reduce(
|
inline auto output_grid_for_col_reduce(
|
||||||
const array& out,
|
const array& out,
|
||||||
const cu::ColReduceArgs& args,
|
const cu::ColReduceArgs& args,
|
||||||
int bn) {
|
int bn,
|
||||||
|
int outer = 1) {
|
||||||
int gx, gy = 1;
|
int gx, gy = 1;
|
||||||
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
|
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
|
||||||
size_t n_outer_blocks = out.size() / args.reduction_stride;
|
size_t n_outer_blocks = out.size() / args.reduction_stride;
|
||||||
size_t n_blocks = n_outer_blocks * n_inner_blocks;
|
size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
|
||||||
while (n_blocks / gy > INT32_MAX) {
|
while (n_blocks / gy > INT32_MAX) {
|
||||||
gy *= 2;
|
gy *= 2;
|
||||||
}
|
}
|
||||||
@@ -277,7 +298,8 @@ void col_reduce_looped(
|
|||||||
0,
|
0,
|
||||||
indata,
|
indata,
|
||||||
gpu_ptr<U>(out),
|
gpu_ptr<U>(out),
|
||||||
static_cast<cu::ColReduceArgs>(args));
|
static_cast<cu::ColReduceArgs>(args),
|
||||||
|
out.size() / args.reduction_stride);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -320,6 +342,117 @@ void col_reduce_small(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void col_reduce_two_pass(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType reduce_type,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const ReductionPlan& plan,
|
||||||
|
const cu::ColReduceArgs& args) {
|
||||||
|
// Allocate data for the output using in's layout to access them as
|
||||||
|
// contiguously as possible.
|
||||||
|
allocate_same_layout(out, in, axes, encoder);
|
||||||
|
|
||||||
|
// Allocate an intermediate array to hold the 1st pass result
|
||||||
|
constexpr int outer = 32;
|
||||||
|
|
||||||
|
Shape intermediate_shape;
|
||||||
|
intermediate_shape.push_back(outer);
|
||||||
|
intermediate_shape.insert(
|
||||||
|
intermediate_shape.end(), out.shape().begin(), out.shape().end());
|
||||||
|
|
||||||
|
Strides intermediate_strides;
|
||||||
|
intermediate_strides.push_back(out.size());
|
||||||
|
intermediate_strides.insert(
|
||||||
|
intermediate_strides.end(), out.strides().begin(), out.strides().end());
|
||||||
|
|
||||||
|
array intermediate(intermediate_shape, out.dtype(), nullptr, {});
|
||||||
|
auto [data_size, rc, cc] =
|
||||||
|
check_contiguity(intermediate_shape, intermediate_strides);
|
||||||
|
auto fl = out.flags();
|
||||||
|
fl.row_contiguous = rc;
|
||||||
|
fl.col_contiguous = cc;
|
||||||
|
fl.contiguous = true;
|
||||||
|
intermediate.set_data(
|
||||||
|
cu::malloc_async(intermediate.nbytes(), encoder),
|
||||||
|
data_size,
|
||||||
|
intermediate_strides,
|
||||||
|
fl,
|
||||||
|
allocator::free);
|
||||||
|
|
||||||
|
encoder.add_temporary(intermediate);
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(intermediate);
|
||||||
|
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||||
|
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||||
|
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
||||||
|
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||||
|
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
using U = typename cu::ReduceResult<OP, T>::type;
|
||||||
|
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||||
|
T* indata = const_cast<T*>(gpu_ptr<T>(in));
|
||||||
|
|
||||||
|
constexpr int N_READS = 4;
|
||||||
|
constexpr int BM = 32;
|
||||||
|
constexpr int BN = 32;
|
||||||
|
dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
|
||||||
|
int blocks = BM * BN / N_READS;
|
||||||
|
auto kernel = cu::
|
||||||
|
col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
grid,
|
||||||
|
blocks,
|
||||||
|
0,
|
||||||
|
indata,
|
||||||
|
gpu_ptr<U>(intermediate),
|
||||||
|
static_cast<cu::ColReduceArgs>(args),
|
||||||
|
out.size() / args.reduction_stride);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// Prepare the reduction arguments for the 2nd pass
|
||||||
|
cu::ColReduceArgs second_args = args;
|
||||||
|
second_args.reduction_size = outer;
|
||||||
|
second_args.reduction_stride = out.size();
|
||||||
|
second_args.ndim = 0;
|
||||||
|
second_args.reduce_shape[0] = outer;
|
||||||
|
second_args.reduce_strides[0] = out.size();
|
||||||
|
second_args.reduce_ndim = 1;
|
||||||
|
second_args.non_col_reductions = 1;
|
||||||
|
|
||||||
|
encoder.set_input_array(intermediate);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
|
||||||
|
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||||
|
dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
|
||||||
|
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||||
|
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
using U = typename cu::ReduceResult<OP, T>::type;
|
||||||
|
|
||||||
|
constexpr int N_READS = 4;
|
||||||
|
constexpr int BM = 32;
|
||||||
|
constexpr int BN = 32;
|
||||||
|
dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
|
||||||
|
int blocks = BM * BN / N_READS;
|
||||||
|
auto kernel =
|
||||||
|
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
grid,
|
||||||
|
blocks,
|
||||||
|
0,
|
||||||
|
gpu_ptr<T>(intermediate),
|
||||||
|
gpu_ptr<U>(out),
|
||||||
|
second_args,
|
||||||
|
second_args.reduction_stride);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void col_reduce(
|
void col_reduce(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
const array& in,
|
const array& in,
|
||||||
@@ -334,6 +467,18 @@ void col_reduce(
|
|||||||
// It is a general strided reduce. Each threadblock computes the output for
|
// It is a general strided reduce. Each threadblock computes the output for
|
||||||
// a subrow of the fast moving axis. For instance 32 elements.
|
// a subrow of the fast moving axis. For instance 32 elements.
|
||||||
//
|
//
|
||||||
|
// - col_reduce_small
|
||||||
|
//
|
||||||
|
// It is a column reduce for small columns. Each thread loops over the whole
|
||||||
|
// column without communicating with any other thread.
|
||||||
|
//
|
||||||
|
// - col_reduce_two_pass
|
||||||
|
//
|
||||||
|
// It is a reduce for long columns. To increase parallelism, we split the
|
||||||
|
// reduction in two passes. First we do a column reduce where many
|
||||||
|
// threadblocks operate on different parts of the reduced axis. Then we
|
||||||
|
// perform a final column reduce.
|
||||||
|
//
|
||||||
// Notes: As in row reduce we opt to read as much in order as possible and
|
// Notes: As in row reduce we opt to read as much in order as possible and
|
||||||
// leave transpositions as they are (contrary to our Metal backend).
|
// leave transpositions as they are (contrary to our Metal backend).
|
||||||
//
|
//
|
||||||
@@ -349,6 +494,14 @@ void col_reduce(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Long column with smallish row
|
||||||
|
size_t total_sums = args.non_col_reductions * args.reduction_size;
|
||||||
|
size_t approx_threads = out.size();
|
||||||
|
if (total_sums / approx_threads > 32) {
|
||||||
|
col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Fallback col reduce
|
// Fallback col reduce
|
||||||
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
|
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -210,6 +210,14 @@ class TestReduce(mlx_tests.MLXTestCase):
|
|||||||
ref = getattr(np, op)(np_arr, axis=axis)
|
ref = getattr(np, op)(np_arr, axis=axis)
|
||||||
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
||||||
|
|
||||||
|
def test_long_column(self):
|
||||||
|
a = (np.random.randn(8192, 64) * 32).astype(np.int32)
|
||||||
|
b = mx.array(a)
|
||||||
|
|
||||||
|
c1 = a.sum(0)
|
||||||
|
c2 = b.sum(0)
|
||||||
|
self.assertTrue(np.all(c1 == c2))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner(failfast=True)
|
mlx_tests.MLXTestRunner(failfast=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user