// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;
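// Single-pass partial reduction kernel. Each block owns a contiguous
// `block_step`-sized slice of the input: every thread loads N elements per
// iteration (vectorized in the main loop; bounds-checked and padded with the
// reduction identity in the tail), folds them into a per-thread accumulator,
// the block combines the accumulators, and thread 0 writes one partial
// result per block to out[block_rank].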
template <typename T, typename U, typename ReduceOp, int N = 4>
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
  // TODO: Process multiple "rows" in each thread
  constexpr int M = 1;

  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WARP_SIZE>(block);

  const U init = cu::ReduceInit<ReduceOp, T>::value();
  ReduceOp op;

  T vals[N];
  U accs[M];
  accs[0] = init;

  size_t start = grid.block_rank() * block_step;
  size_t end = start + block_step;
  size_t check = min(end, size);

  size_t i = start;
  for (; i + block.size() * N <= check; i += block.size() * N) {
    cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
    for (int j = 0; j < N; j++) {
      accs[0] = op(accs[0], cast_to<U>(vals[j]));
    }
  }

  if (i < check) {
    cub::LoadDirectBlocked(
        block.thread_rank(), in + i, vals, check - i, cast_to<T>(init));
    for (int i = 0; i < N; i++) {
      accs[0] = op(accs[0], cast_to<U>(vals[i]));
    }
  }

  __shared__ U shared_accumulators[32];
  block_reduce(block, warp, accs, shared_accumulators, op, init);

  if (block.thread_rank() == 0) {
    out[grid.block_rank()] = accs[0];
  }
}

} // namespace cu
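// Host-side driver. Small inputs are reduced directly into `out` with a
// single launch; larger inputs run in two passes: the first launch writes
// one partial (already in the accumulator type U) per block into a temporary
// `intermediate` array, and a second launch, which the get_args heuristics
// resolve to a single block, reduces those partials into `out`.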
void all_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type) {
  constexpr int N_READS = 8;

  out.set_data(cu::malloc_async(out.nbytes(), encoder));
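  // Choose a launch configuration for reducing `size` elements with `N`
  // reads per thread: round the thread count to a warp multiple, cap the
  // grid at 1024 blocks, and give each block an equal `block_step`-sized
  // slice. E.g. for size = 1 << 24 and N = 8: threads = 512,
  // reductions_per_step = 4096, steps_needed = 4096, so blocks = 1024,
  // steps_per_block = 4, and block_step = 16384 elements per block.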
  auto get_args = [](size_t size, int N) {
    int threads = std::min(512UL, (size + N - 1) / N);
    threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
    int reductions_per_step = threads * N;
    size_t steps_needed =
        (size + reductions_per_step - 1) / reductions_per_step;

    int blocks;
    if (steps_needed < 32) {
      blocks = 1;
    } else if (steps_needed < 128) {
      blocks = 32;
    } else if (steps_needed < 512) {
      blocks = 128;
    } else if (steps_needed < 1024) {
      blocks = 512;
    } else {
      blocks = 1024;
    }

    size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
    size_t block_step = steps_per_block * reductions_per_step;

    return std::make_tuple(blocks, threads, block_step);
  };
  int blocks, threads;
  size_t block_step;
  size_t insize = in.size();
  Dtype dt = in.dtype();

  // Cub doesn't like const pointers for load (sigh).
  void* indata = const_cast<void*>(gpu_ptr<void>(in));

  // Large array so allocate an intermediate and accumulate there
  std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
  encoder.set_input_array(in);
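  // First pass (only needed when more than one block is launched): reduce
  // the input into one partial per block, stored in a temporary array of the
  // output dtype, then retarget the next launch at that intermediate.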
  if (blocks > 1) {
    array intermediate({blocks}, out.dtype(), nullptr, {});
    intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
    encoder.add_temporary(intermediate);
    encoder.set_output_array(intermediate);
    dispatch_all_types(dt, [&](auto type_tag) {
      dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
        using OP = MLX_GET_TYPE(reduce_type_tag);
        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
        using U = typename cu::ReduceResult<OP, T>::type;
        auto kernel = cu::all_reduce<T, U, OP, N_READS>;
        encoder.add_kernel_node(
            kernel,
            blocks,
            threads,
            0,
            static_cast<T*>(indata),
            gpu_ptr<U>(intermediate),
            block_step,
            insize);
      });
    });

    // Set the input for the next step and recalculate the blocks
    indata = gpu_ptr<void>(intermediate);
    dt = intermediate.dtype();
    insize = intermediate.size();
    std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
    encoder.set_input_array(intermediate);
  }
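  // Final pass: reduce whatever `indata` now points at (the original input
  // or the per-block partials) directly into `out`.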
  encoder.set_output_array(out);
  dispatch_all_types(dt, [&](auto type_tag) {
    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
      using OP = MLX_GET_TYPE(reduce_type_tag);
      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
      using U = typename cu::ReduceResult<OP, T>::type;
      auto kernel = cu::all_reduce<T, U, OP, N_READS>;
      encoder.add_kernel_node(
          kernel,
          blocks,
          threads,
          0,
          static_cast<T*>(indata),
          gpu_ptr<U>(out),
          block_step,
          insize);
    });
  });
}

} // namespace mlx::core