mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Comments
This commit is contained in:
@@ -96,43 +96,52 @@ void all_reduce(
|
||||
|
||||
int blocks, threads;
|
||||
size_t block_step;
|
||||
array x = in;
|
||||
size_t insize = in.size();
|
||||
Dtype dt = in.dtype();
|
||||
|
||||
// Cub doesn't like const pointers for load (sigh).
|
||||
void* indata = const_cast<void*>(in.data<void>());
|
||||
|
||||
// Large array so allocate an intermediate and accumulate there
|
||||
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
||||
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
|
||||
encoder.set_input_array(in);
|
||||
if (blocks > 1) {
|
||||
array intermediate({blocks}, out.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
encoder.add_temporary(intermediate);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_output_array(intermediate);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
|
||||
MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using T = cuda_type_t<CTYPE>;
|
||||
using U = cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
x.data<T>(), intermediate.data<U>(), block_step, x.size());
|
||||
static_cast<T*>(indata),
|
||||
intermediate.data<U>(),
|
||||
block_step,
|
||||
insize);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Set the input for the next step and recalculate the blocks
|
||||
x = intermediate;
|
||||
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
|
||||
indata = intermediate.data<void>();
|
||||
dt = intermediate.dtype();
|
||||
insize = intermediate.size();
|
||||
std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
|
||||
encoder.set_input_array(intermediate);
|
||||
}
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
|
||||
MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using T = cuda_type_t<CTYPE>;
|
||||
using U = cu::ReduceResult<OP, T>::type;
|
||||
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
x.data<T>(), out.data<U>(), block_step, x.size());
|
||||
static_cast<T*>(indata), out.data<U>(), block_step, insize);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -38,12 +38,18 @@ struct ColReduceArgs {
|
||||
const array& in,
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes) {
|
||||
using ShapeVector = decltype(plan.shape);
|
||||
using StridesVector = decltype(plan.strides);
|
||||
|
||||
ShapeVector shape_vec;
|
||||
StridesVector strides_vec;
|
||||
|
||||
assert(!plan.shape.empty());
|
||||
reduction_size = plan.shape.back();
|
||||
reduction_stride = plan.strides.back();
|
||||
|
||||
int64_t stride_back = 1;
|
||||
auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
|
||||
std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes);
|
||||
while (!shape_vec.empty() && stride_back < reduction_stride) {
|
||||
stride_back *= shape_vec.back();
|
||||
shape_vec.pop_back();
|
||||
@@ -54,8 +60,8 @@ struct ColReduceArgs {
|
||||
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
|
||||
return strides_vec[left] > strides_vec[right];
|
||||
});
|
||||
decltype(shape_vec) sorted_shape;
|
||||
decltype(strides_vec) sorted_strides;
|
||||
ShapeVector sorted_shape;
|
||||
StridesVector sorted_strides;
|
||||
for (auto idx : indices) {
|
||||
sorted_shape.push_back(shape_vec[idx]);
|
||||
sorted_strides.push_back(strides_vec[idx]);
|
||||
@@ -206,26 +212,25 @@ void col_reduce_looped(
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
|
||||
// Just a way to get out of the constness because cub doesn't like it ...
|
||||
// (sigh)
|
||||
array x = in;
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
|
||||
using T = cuda_type_t<CTYPE>;
|
||||
using U = cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
|
||||
constexpr int N_READS = 4;
|
||||
constexpr int BM = 32;
|
||||
constexpr int BN = 32;
|
||||
dim3 grid = output_grid_for_col_reduce(out, args, BN);
|
||||
int blocks = BM * BN / N_READS;
|
||||
auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
|
||||
kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);
|
||||
kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -247,8 +252,7 @@ void col_reduce(
|
||||
// a subrow of the fast moving axis. For instance 32 elements.
|
||||
//
|
||||
// Notes: As in row reduce we opt to read as much in order as possible and
|
||||
// leave
|
||||
// transpositions as they are (contrary to our Metal backend).
|
||||
// leave transpositions as they are (contrary to our Metal backend).
|
||||
//
|
||||
// Moreover we need different kernels for short rows and tuning
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ void init_reduce(
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
|
||||
@@ -146,7 +146,7 @@ inline void allocate_same_layout(
|
||||
auto fl = in.flags();
|
||||
fl.row_contiguous = rc;
|
||||
fl.col_contiguous = cc;
|
||||
fl.contiguous = data_size == out.size();
|
||||
fl.contiguous = true;
|
||||
out.set_data(
|
||||
allocator::malloc(out.nbytes()),
|
||||
data_size,
|
||||
|
||||
@@ -241,20 +241,19 @@ void row_reduce_simple(
|
||||
// kernel.
|
||||
allocate_same_layout(out, in, axes);
|
||||
|
||||
// Just a way to get out of the constness because cub doesn't like it ...
|
||||
// (sigh)
|
||||
array x = in;
|
||||
|
||||
// TODO: If out.size() < 1024 which will be a common case then write this in
|
||||
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using T = cuda_type_t<CTYPE>;
|
||||
using U = cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
|
||||
// Calculate the grid and block dims
|
||||
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
@@ -271,7 +270,7 @@ void row_reduce_simple(
|
||||
|
||||
// Launch
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
x.data<T>(), out.data<U>(), out.size(), plan.shape.back());
|
||||
indata, out.data<U>(), out.size(), plan.shape.back());
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -291,20 +290,19 @@ void row_reduce_looped(
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
|
||||
// Just a way to get out of the constness because cub doesn't like it ...
|
||||
// (sigh)
|
||||
array x = in;
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||
using T = cuda_type_t<CTYPE>;
|
||||
using U = cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
|
||||
// Calculate the grid and block dims
|
||||
args.sort_access_pattern(x, axes);
|
||||
args.sort_access_pattern(in, axes);
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||
int threads = std::min(1024UL, reductions);
|
||||
@@ -322,7 +320,7 @@ void row_reduce_looped(
|
||||
|
||||
// Launch
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
x.data<T>(), out.data<U>(), out.size(), args);
|
||||
indata, out.data<U>(), out.size(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user