Add all reduce and atomic updates

Angelos Katharopoulos 2025-06-17 23:58:51 -07:00
parent ab7c310914
commit 9cf7ef1068
6 changed files with 259 additions and 22 deletions

View File

@@ -29,6 +29,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
+ ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu

View File

@@ -157,7 +157,7 @@ void binary_op_gpu_inplace(
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel =
- &cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
+ cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(

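Dropping the `&` is purely cosmetic: a function name such as this cu::binary_g_nd specialization decays to the same function pointer with or without the address-of operator, so get_launch_args and the <<<...>>> launch receive an identical value either way.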
View File

@@ -21,29 +21,15 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(!axes_.empty());
assert(out.size() != in.size());
- out.set_data(allocator::malloc(out.nbytes()));
+ // out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
- encoder.set_input_array(in);
- encoder.set_output_array(out);
+ // encoder.set_input_array(in);
+ // encoder.set_output_array(out);
// Fill out with init value.
if (in.size() == 0) {
- encoder.launch_kernel([&](cudaStream_t stream) {
-   MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-     MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
-       using InType = cuda_type_t<CTYPE>;
-       using OutType = cu::ReduceResult<OP, InType>::type;
-       thrust::fill_n(
-           cu::thrust_policy(stream),
-           thrust::device_pointer_cast(out.data<OutType>()),
-           out.data_size(),
-           cu::ReduceInit<OP, InType>::value());
-     });
-   });
- });
- return;
+ throw std::runtime_error("Should never reach here.");
}
// Reduce.
@@ -59,8 +45,12 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
plan = get_reduction_plan(in, axes_);
}
- if ((plan.type == ContiguousAllReduce) ||
-     (plan.type == ContiguousReduce && plan.shape.size() == 1)) {
+ if (plan.type == ContiguousAllReduce) {
+   all_reduce(encoder, in, out, reduce_type_);
+   return;
+ }
+ if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
return;
}

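Splitting the combined branch lets a full reduction down to a single scalar take the new all_reduce path, while a contiguous reduction over one axis keeps going through segmented_reduce.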
View File

@@ -0,0 +1,160 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
namespace {
// TODO: Should make a custom complex type
template <typename U, typename T>
inline __device__ U __cast(T x) {
return static_cast<U>(x);
}
template <>
inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
  // A complex value is truthy when either component is nonzero.
  return x.x != 0 || x.y != 0;
}
template <>
inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
  return x ? make_cuFloatComplex(1, 0) : make_cuFloatComplex(0, 0);
}
} // namespace
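// Each block reduces a contiguous `block_step`-element slice of `in` to one
// partial result: every thread accumulates N elements per load, warps combine
// lanes with cg::reduce, and warp results meet in shared memory. `out` gets
// one partial per block; a second pass reduces those partials if needed.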
template <typename T, typename U, typename ReduceOp, int N = 4>
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
const U init = cu::ReduceInit<ReduceOp, T>::value();
ReduceOp op;
T vals[N];
U accs[N];
for (int i = 0; i < N; i++) {
accs[i] = init;
}
size_t start = grid.block_rank() * block_step;
size_t end = start + block_step;
size_t check = min(end, size);
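  // Main loop: full chunks of block.size() * N elements, loaded vectorized.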
  size_t i = start;
  for (; i + block.size() * N <= check; i += block.size() * N) {
    cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
    for (int j = 0; j < N; j++) {
      accs[j] = op(accs[j], __cast<U, T>(vals[j]));
    }
  }
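  // Ragged tail: handle the remaining (< block.size() * N) elements of this
  // block's slice with a bounds-checked load, padding with init.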
  if (i < check) {
    cub::LoadDirectBlocked(
        block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
    for (int j = 0; j < N; j++) {
      accs[j] = op(accs[j], __cast<U, T>(vals[j]));
    }
}
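  // Fold the N per-thread accumulators, reduce within each warp, then combine
  // the warp results through shared memory.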
for (int i = 1; i < N; i++) {
accs[0] = op(accs[0], accs[i]);
}
accs[0] = cg::reduce(warp, accs[0], op);
__shared__ U shared_accumulators[32];
if (warp.thread_rank() == 0) {
shared_accumulators[warp.meta_group_rank()] = accs[0];
}
block.sync();
accs[0] = (warp.thread_rank() < warp.meta_group_size())
? shared_accumulators[warp.thread_rank()]
: init;
accs[0] = cg::reduce(warp, accs[0], op);
if (block.thread_rank() == 0) {
out[grid.block_rank()] = accs[0];
}
}
} // namespace cu
void all_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type) {
constexpr int N_READS = 4;
out.set_data(allocator::malloc(out.nbytes()));
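  // Launch geometry: up to 1024 blocks of 1024 threads, each block owning a
  // contiguous `block_step`-element slice (N elements per thread per load).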
auto get_args = [](size_t size, int N) {
    // Ceil so trailing elements are still assigned to a block when size is
    // not a multiple of N.
    size_t reductions = (size + N - 1) / N;
int threads = 1024;
int blocks = std::min(1024UL, (reductions + threads - 1) / threads);
size_t reductions_per_block = std::max(
static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
size_t block_step = reductions_per_block * N_READS;
return std::make_tuple(blocks, threads, block_step);
};
int blocks, threads;
size_t block_step;
bool large = in.size() > N_READS * 1024;
array x = in;
// Large array so allocate an intermediate and accumulate there
if (large) {
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
encoder.add_temporary(intermediate);
encoder.set_input_array(x);
encoder.set_output_array(intermediate);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
x.data<T>(), intermediate.data<U>(), block_step, x.size());
});
});
});
x = intermediate;
}
// Final reduction
{
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
encoder.set_input_array(x);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
x.data<T>(), out.data<U>(), block_step, x.size());
});
});
});
}
}
} // namespace mlx::core

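For intuition, the geometry arithmetic in get_args can be traced standalone. The sketch below is hypothetical driver code (not part of this commit) that mirrors the lambda for a 16M-element input:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  size_t size = 1 << 24; // 16,777,216 elements
  constexpr int N_READS = 4;
  size_t reductions = (size + N_READS - 1) / N_READS; // 4,194,304 loads
  int threads = 1024;
  int blocks = std::min(1024UL, (reductions + threads - 1) / threads); // 1024
  size_t reductions_per_block = std::max(
      static_cast<size_t>(threads), (reductions + blocks - 1) / blocks); // 4096
  size_t block_step = reductions_per_block * N_READS; // 16,384 per block
  // 1024 blocks * 16,384 elements covers all 16,777,216 inputs exactly, and
  // the 1024 per-block partials are then reduced by a single final launch.
  std::printf("blocks=%d threads=%d block_step=%zu\n", blocks, threads, block_step);
  return 0;
}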
View File

@@ -47,6 +47,12 @@ namespace mlx::core {
throw std::invalid_argument("Unknown reduce type."); \
}
+ void all_reduce(
+     cu::CommandEncoder& encoder,
+     const array& in,
+     array& out,
+     Reduce::ReduceType reduce_type);
void segmented_reduce(
cu::CommandEncoder& encoder,
const array& in,

View File

@@ -6,17 +6,65 @@
namespace mlx::core::cu {
+ template <size_t N>
+ struct uint_by_size;
+ template <>
+ struct uint_by_size<2> {
+   using type = uint16_t;
+ };
+ template <>
+ struct uint_by_size<4> {
+   using type = uint32_t;
+ };
+ template <>
+ struct uint_by_size<8> {
+   using type = unsigned long long int;
+ };
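+ // Generic atomic read-modify-write built on atomicCAS: read the word, apply
+ // Op, then retry until no other thread has written in between. One-byte
+ // types are widened to the aligned 16-bit word containing them and updated
+ // with a shift/mask so only the target byte changes.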
+ template <typename T, typename Op>
+ __device__ void atomic_reduce(T* x, T y) {
+   if constexpr (sizeof(T) == 1) {
+     using U = uint16_t;
+     U* x_int = (U*)((char*)x - ((size_t)x % 2));
+     int shift = ((char*)x - (char*)x_int) * 8;
+     int mask = 0xff << shift;
+     U old_val, new_val;
+     do {
+       old_val = *x_int;
+       T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
+       // Mask the shifted result so a sign-extended byte cannot clobber
+       // its neighbor.
+       new_val = (old_val & ~mask) | ((result << shift) & mask);
+     } while (atomicCAS(x_int, old_val, new_val) != old_val);
+   } else {
+     using U = typename uint_by_size<sizeof(T)>::type;
+     U* x_int = (U*)(x);
+     U old_val, new_val;
+     do {
+       old_val = *x_int;
+       T result = Op{}(*((T*)&old_val), y);
+       new_val = *((U*)&result);
+     } while (atomicCAS(x_int, old_val, new_val) != old_val);
+   }
+ }
// Reduce ops.
struct And {
__device__ bool operator()(bool a, bool b) {
return a && b;
}
+   __device__ void atomic_update(bool* x, bool y) {
+     atomic_reduce<bool, And>(x, y);
+   }
};
struct Or {
__device__ bool operator()(bool a, bool b) {
return a || b;
}
+   __device__ void atomic_update(bool* x, bool y) {
+     atomic_reduce<bool, Or>(x, y);
+   }
};
struct Sum {
@@ -24,6 +72,23 @@ struct Sum {
__device__ T operator()(T a, T b) {
return a + b;
}
+   template <typename T>
+   __device__ void atomic_update(T* x, T y) {
+     atomic_reduce<T, Sum>(x, y);
+   }
+   __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
+     atomicAdd(x, y);
+   }
+   __device__ void atomic_update(int* x, int y) {
+     atomicAdd(x, y);
+   }
+   __device__ void atomic_update(float* x, float y) {
+     atomicAdd(x, y);
+   }
};
struct Prod {
@@ -31,6 +96,11 @@ struct Prod {
__device__ T operator()(T a, T b) {
return a * b;
}
+   template <typename T>
+   __device__ void atomic_update(T* x, T y) {
+     atomic_reduce<T, Prod>(x, y);
+   }
};
struct Min {
@@ -38,6 +108,11 @@ struct Min {
__device__ T operator()(T a, T b) {
return a < b ? a : b;
}
+   template <typename T>
+   __device__ void atomic_update(T* x, T y) {
+     atomic_reduce<T, Min>(x, y);
+   }
};
struct Max {
@@ -45,6 +120,11 @@ struct Max {
__device__ T operator()(T a, T b) {
return a > b ? a : b;
}
+   template <typename T>
+   __device__ void atomic_update(T* x, T y) {
+     atomic_reduce<T, Max>(x, y);
+   }
};
// Traits to get the result type of reduce op.
@@ -120,7 +200,7 @@ template <typename T>
struct ReduceInit<Prod, T> {
static constexpr __host__ __device__ auto value() {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-       return T{1, 1};
+       return T{1, 0};
} else {
return typename ReduceResult<Prod, T>::type{1};
}
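
The atomic_update hooks give every reduce op a uniform device-side interface. As a minimal usage sketch (a hypothetical kernel, not part of this commit, assuming the ops above are in scope), summing one float per thread into a single cell resolves to the plain atomicAdd overload:

__global__ void sum_into(const float* vals, float* total, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Overload resolution picks Sum::atomic_update(float*, float).
    mlx::core::cu::Sum{}.atomic_update(total, vals[i]);
  }
}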