mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Add all reduce and atomic updates
This commit is contained in:
parent
ab7c310914
commit
9cf7ef1068
@ -29,6 +29,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
|
||||
|
@ -157,7 +157,7 @@ void binary_op_gpu_inplace(
|
||||
if (ndim <= 3) {
|
||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||
auto kernel =
|
||||
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
|
||||
cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, out_a, large);
|
||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||
|
@ -21,29 +21,15 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(!axes_.empty());
|
||||
assert(out.size() != in.size());
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
// out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
// encoder.set_input_array(in);
|
||||
// encoder.set_output_array(out);
|
||||
|
||||
// Fill out with init value.
|
||||
if (in.size() == 0) {
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
|
||||
using InType = cuda_type_t<CTYPE>;
|
||||
using OutType = cu::ReduceResult<OP, InType>::type;
|
||||
thrust::fill_n(
|
||||
cu::thrust_policy(stream),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
out.data_size(),
|
||||
cu::ReduceInit<OP, InType>::value());
|
||||
});
|
||||
});
|
||||
});
|
||||
return;
|
||||
throw std::runtime_error("Should never reach here.");
|
||||
}
|
||||
|
||||
// Reduce.
|
||||
@ -59,8 +45,12 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
plan = get_reduction_plan(in, axes_);
|
||||
}
|
||||
|
||||
if ((plan.type == ContiguousAllReduce) ||
|
||||
(plan.type == ContiguousReduce && plan.shape.size() == 1)) {
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
all_reduce(encoder, in, out, reduce_type_);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
|
||||
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
|
||||
return;
|
||||
}
|
||||
|
160
mlx/backend/cuda/reduce/all_reduce.cu
Normal file
160
mlx/backend/cuda/reduce/all_reduce.cu
Normal file
@ -0,0 +1,160 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO: Should make a custom complex type
|
||||
template <typename U, typename T>
|
||||
inline __device__ U __cast(T x) {
|
||||
return static_cast<U>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
|
||||
return x.x != 0 && x.y != 0;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
|
||||
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T, typename U, typename ReduceOp, int N = 4>
|
||||
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
const U init = cu::ReduceInit<ReduceOp, T>::value();
|
||||
ReduceOp op;
|
||||
|
||||
T vals[N];
|
||||
U accs[N];
|
||||
for (int i = 0; i < N; i++) {
|
||||
accs[i] = init;
|
||||
}
|
||||
|
||||
size_t start = grid.block_rank() * block_step;
|
||||
size_t end = start + block_step;
|
||||
size_t check = min(end, size);
|
||||
|
||||
for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) {
|
||||
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
|
||||
for (int i = 0; i < N; i++) {
|
||||
accs[i] = op(accs[i], __cast<U, T>(vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
if (end > size) {
|
||||
size_t offset = end - block.size() * N;
|
||||
int block_end = size - offset;
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(), in + offset, vals, block_end, __cast<T, U>(init));
|
||||
for (int i = 0; i < N; i++) {
|
||||
accs[i] = op(accs[i], __cast<U, T>(vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 1; i < N; i++) {
|
||||
accs[0] = op(accs[0], accs[i]);
|
||||
}
|
||||
accs[0] = cg::reduce(warp, accs[0], op);
|
||||
|
||||
__shared__ U shared_accumulators[32];
|
||||
if (warp.thread_rank() == 0) {
|
||||
shared_accumulators[warp.meta_group_rank()] = accs[0];
|
||||
}
|
||||
block.sync();
|
||||
accs[0] = (warp.thread_rank() < warp.meta_group_size())
|
||||
? shared_accumulators[warp.thread_rank()]
|
||||
: init;
|
||||
accs[0] = cg::reduce(warp, accs[0], op);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[grid.block_rank()] = accs[0];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void all_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type) {
|
||||
constexpr int N_READS = 4;
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto get_args = [](size_t size, int N) {
|
||||
size_t reductions = size / N;
|
||||
int threads = 1024;
|
||||
int blocks = std::min(1024UL, (reductions + threads - 1) / threads);
|
||||
size_t reductions_per_block = std::max(
|
||||
static_cast<size_t>(threads), (reductions + blocks - 1) / blocks);
|
||||
size_t block_step = reductions_per_block * N_READS;
|
||||
|
||||
return std::make_tuple(blocks, threads, block_step);
|
||||
};
|
||||
|
||||
int blocks, threads;
|
||||
size_t block_step;
|
||||
bool large = in.size() > N_READS * 1024;
|
||||
array x = in;
|
||||
|
||||
// Large array so allocate an intermediate and accumulate there
|
||||
if (large) {
|
||||
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
||||
array intermediate({blocks}, out.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
encoder.add_temporary(intermediate);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_output_array(intermediate);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using T = cuda_type_t<CTYPE>;
|
||||
using U = cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
x.data<T>(), intermediate.data<U>(), block_step, x.size());
|
||||
});
|
||||
});
|
||||
});
|
||||
x = intermediate;
|
||||
}
|
||||
|
||||
// Final reduction
|
||||
{
|
||||
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using T = cuda_type_t<CTYPE>;
|
||||
using U = cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
x.data<T>(), out.data<U>(), block_step, x.size());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -47,6 +47,12 @@ namespace mlx::core {
|
||||
throw std::invalid_argument("Unknown reduce type."); \
|
||||
}
|
||||
|
||||
void all_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type);
|
||||
|
||||
void segmented_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
|
@ -6,17 +6,65 @@
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
template <size_t N>
|
||||
struct uint_by_size;
|
||||
template <>
|
||||
struct uint_by_size<2> {
|
||||
using type = uint16_t;
|
||||
};
|
||||
template <>
|
||||
struct uint_by_size<4> {
|
||||
using type = uint32_t;
|
||||
};
|
||||
template <>
|
||||
struct uint_by_size<8> {
|
||||
using type = unsigned long long int;
|
||||
};
|
||||
|
||||
template <typename T, typename Op>
|
||||
__device__ void atomic_reduce(T* x, T y) {
|
||||
if constexpr (sizeof(T) == 1) {
|
||||
using U = uint16_t;
|
||||
U* x_int = (U*)((char*)x - ((size_t)x % 2));
|
||||
int shift = ((char*)x - (char*)x_int) * 8;
|
||||
int mask = 0xff << shift;
|
||||
U old_val, new_val;
|
||||
do {
|
||||
old_val = *x_int;
|
||||
T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
|
||||
new_val = (old_val & ~mask) | (result << shift);
|
||||
} while (atomicCAS(x_int, old_val, new_val) != old_val);
|
||||
} else {
|
||||
using U = typename uint_by_size<sizeof(T)>::type;
|
||||
U* x_int = (U*)(x);
|
||||
U old_val, new_val;
|
||||
do {
|
||||
old_val = *x_int;
|
||||
T result = Op{}(*((T*)&old_val), y);
|
||||
new_val = *((U*)&result);
|
||||
} while (atomicCAS(x_int, old_val, new_val) != old_val);
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce ops.
|
||||
struct And {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
return a && b;
|
||||
}
|
||||
|
||||
__device__ void atomic_update(bool* x, bool y) {
|
||||
atomic_reduce<bool, And>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Or {
|
||||
__device__ bool operator()(bool a, bool b) {
|
||||
return a || b;
|
||||
}
|
||||
|
||||
__device__ void atomic_update(bool* x, bool y) {
|
||||
atomic_reduce<bool, Or>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sum {
|
||||
@ -24,6 +72,23 @@ struct Sum {
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void atomic_update(T* x, T y) {
|
||||
atomic_reduce<T, Sum>(x, y);
|
||||
}
|
||||
|
||||
__device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
|
||||
atomicAdd(x, y);
|
||||
}
|
||||
|
||||
__device__ void atomic_update(int* x, int y) {
|
||||
atomicAdd(x, y);
|
||||
}
|
||||
|
||||
__device__ void atomic_update(float* x, float y) {
|
||||
atomicAdd(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Prod {
|
||||
@ -31,6 +96,11 @@ struct Prod {
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void atomic_update(T* x, T y) {
|
||||
atomic_reduce<T, Prod>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Min {
|
||||
@ -38,6 +108,11 @@ struct Min {
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void atomic_update(T* x, T y) {
|
||||
atomic_reduce<T, Min>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Max {
|
||||
@ -45,6 +120,11 @@ struct Max {
|
||||
__device__ T operator()(T a, T b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void atomic_update(T* x, T y) {
|
||||
atomic_reduce<T, Max>(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
// Traits to get the result type of reduce op.
|
||||
@ -120,7 +200,7 @@ template <typename T>
|
||||
struct ReduceInit<Prod, T> {
|
||||
static constexpr __host__ __device__ auto value() {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
return T{1, 1};
|
||||
return T{1, 0};
|
||||
} else {
|
||||
return typename ReduceResult<Prod, T>::type{1};
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user