Add helpers and atomic kernel

This commit is contained in:
Angelos Katharopoulos 2025-06-21 12:37:35 -07:00
parent 880751a084
commit abdb21f27c
5 changed files with 394 additions and 196 deletions

View File

@ -5,11 +5,9 @@
namespace mlx::core {
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
Shape shape,
Strides strides,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
return std::make_pair(shape, strides);
}
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
return shapes_without_reduction_axes(
std::move(shape), std::move(strides), axes);
}
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&

View File

@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
Shape shape,
Strides strides,
const std::vector<int>& axes);
} // namespace mlx::core

View File

@ -15,6 +15,9 @@ namespace cg = cooperative_groups;
template <typename T, typename U, typename ReduceOp, int N = 4>
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
// TODO: Process multiple "rows" in each thread
constexpr int M = 1;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
@ -23,10 +26,8 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
ReduceOp op;
T vals[N];
U accs[N];
for (int i = 0; i < N; i++) {
accs[i] = init;
}
U accs[M];
accs[0] = init;
size_t start = grid.block_rank() * block_step;
size_t end = start + block_step;
@ -35,7 +36,7 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
for (size_t i = start; i + block.size() * N <= check; i += block.size() * N) {
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
for (int j = 0; j < N; j++) {
accs[j] = op(accs[j], __cast<U, T>(vals[j]));
accs[0] = op(accs[0], __cast<U, T>(vals[j]));
}
}
@ -45,26 +46,12 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
cub::LoadDirectBlocked(
block.thread_rank(), in + offset, vals, block_end, __cast<T, U>(init));
for (int i = 0; i < N; i++) {
accs[i] = op(accs[i], __cast<U, T>(vals[i]));
accs[0] = op(accs[0], __cast<U, T>(vals[i]));
}
}
for (int i = 1; i < N; i++) {
accs[0] = op(accs[0], accs[i]);
}
accs[0] = cg::reduce(warp, accs[0], op);
if (warp.meta_group_size() > 1) {
__shared__ U shared_accumulators[32];
if (warp.thread_rank() == 0) {
shared_accumulators[warp.meta_group_rank()] = accs[0];
}
block.sync();
accs[0] = (warp.thread_rank() < warp.meta_group_size())
? shared_accumulators[warp.thread_rank()]
: init;
accs[0] = cg::reduce(warp, accs[0], op);
}
block_reduce(block, warp, accs, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
out[grid.block_rank()] = accs[0];

View File

@ -4,7 +4,14 @@
#include "mlx/backend/cuda/device/utils.cuh"
namespace mlx::core::cu {
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <size_t N>
struct uint_by_size;
@ -62,4 +69,66 @@ inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
}
} // namespace mlx::core::cu
template <typename T, int N, typename Block, typename Warp, typename Op>
inline __device__ void
block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
// First reduce in the current warp
for (int i = 0; i < N; i++) {
vals[i] = cg::reduce(warp, vals[i], op);
}
// Reduce across warps
if (warp.meta_group_size() > 1) {
if (warp.thread_rank() == 0) {
for (int i = 0; i < N; i++) {
smem[warp.meta_group_rank() * N + i] = vals[i];
}
}
block.sync();
if (warp.thread_rank() < warp.meta_group_size()) {
for (int i = 0; i < N; i++) {
vals[i] = smem[warp.thread_rank() * N + i];
}
} else {
for (int i = 0; i < N; i++) {
vals[i] = init;
}
}
for (int i = 0; i < N; i++) {
vals[i] = cg::reduce(warp, vals[i], op);
}
}
}
} // namespace cu
inline void allocate_same_layout(
array& out,
const array& in,
const std::vector<int>& axes) {
// Initialize out such that it matches in's layout. Basically we keep any
// transpositions as it were and that allows us either to skip finding the
// location of the output that matches the input or simply contiguous read or
// writes.
auto out_strides = in.strides();
for (auto ax : axes) {
for (auto& s : out_strides) {
if (s > in.strides(ax)) {
s /= in.shape(ax);
}
}
}
auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
auto fl = in.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = data_size == out.size();
out.set_data(
allocator::malloc(out.nbytes()),
data_size,
out_strides,
fl,
allocator::free);
}
} // namespace mlx::core

View File

@ -55,86 +55,109 @@ struct RowReduceArgs {
non_row_reductions *= reduce_shape[i];
}
}
// Convert shape and strides as if in was contiguous
void convert_shapes_to_contiguous(
const array& in,
const std::vector<int>& axes) {
auto shape_vec = in.shape();
auto strides_vec = in.strides();
size_t s = 1;
for (int i = in.ndim() - 1; i >= 0; i--) {
strides_vec[i] = s;
s *= shape_vec[i];
}
std::tie(shape_vec, strides_vec) =
shapes_without_reduction_axes(shape_vec, strides_vec, axes);
std::tie(shape_vec, strides_vec) =
collapse_contiguous_dims(shape_vec, strides_vec);
shape = const_param(shape_vec);
strides = const_param(strides_vec);
ndim = shape_vec.size();
}
};
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_small(
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
size_t out_idx = cg::this_grid().thread_rank();
if (out_idx >= out_size) {
return;
}
Op op;
U total_val = ReduceInit<Op, T>::value();
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = 0; n < args.non_row_reductions; n++) {
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
U vals[N_READS];
cub::LoadDirectBlocked(
r,
make_cast_iterator<U>(in + loop.location()),
vals,
args.row_size,
ReduceInit<Op, T>::value());
total_val = op(total_val, cub::ThreadReduce(vals, op));
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
out[out_idx] = total_val;
}
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
__global__ void row_reduce_small_warp(
const T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
size_t out_idx = grid.thread_rank() / WARP_SIZE;
if (out_idx >= out_size) {
return;
}
Op op;
U total_val = ReduceInit<Op, T>::value();
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
n += WARP_SIZE) {
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
U vals[N_READS];
cub::LoadDirectBlocked(
r,
make_cast_iterator<U>(in + loop.location()),
vals,
args.row_size,
ReduceInit<Op, T>::value());
total_val = op(total_val, cub::ThreadReduce(vals, op));
}
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
}
total_val = cg::reduce(warp, total_val, op);
if (warp.thread_rank() == 0) {
out[out_idx] = total_val;
}
}
// template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
//__global__ void row_reduce_small(
// const T* in,
// U* out,
// size_t out_size,
// const __grid_constant__ RowReduceArgs args) {
// size_t out_idx = cg::this_grid().thread_rank();
// if (out_idx >= out_size) {
// return;
// }
//
// Op op;
//
// U total_val = ReduceInit<Op, T>::value();
// LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
//
// in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(),
// args.ndim);
//
// for (size_t n = 0; n < args.non_row_reductions; n++) {
// for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
// U vals[N_READS];
// cub::LoadDirectBlocked(
// r,
// make_cast_iterator<U>(in + loop.location()),
// vals,
// args.row_size,
// ReduceInit<Op, T>::value());
// total_val = op(total_val, cub::ThreadReduce(vals, op));
// }
// loop.next(args.reduce_shape.data(), args.reduce_strides.data());
// }
//
// out[out_idx] = total_val;
// }
//
// template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
//__global__ void row_reduce_small_warp(
// const T* in,
// U* out,
// size_t out_size,
// const __grid_constant__ RowReduceArgs args) {
// auto grid = cg::this_grid();
// auto block = cg::this_thread_block();
// auto warp = cg::tiled_partition<WARP_SIZE>(block);
//
// size_t out_idx = grid.thread_rank() / WARP_SIZE;
// if (out_idx >= out_size) {
// return;
// }
//
// Op op;
//
// U total_val = ReduceInit<Op, T>::value();
// LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
//
// in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(),
// args.ndim);
//
// for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
// n += WARP_SIZE) {
// for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
// U vals[N_READS];
// cub::LoadDirectBlocked(
// r,
// make_cast_iterator<U>(in + loop.location()),
// vals,
// args.row_size,
// ReduceInit<Op, T>::value());
// total_val = op(total_val, cub::ThreadReduce(vals, op));
// }
// loop.next(WARP_SIZE, args.reduce_shape.data(),
// args.reduce_strides.data());
// }
//
// total_val = cg::reduce(warp, total_val, op);
//
// if (warp.thread_rank() == 0) {
// out[out_idx] = total_val;
// }
// }
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
@ -153,59 +176,37 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
const size_t start_row =
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
const size_t full_blocks = size / (block.size() * N);
const size_t final_offset = full_blocks * (block.size() * N);
in += start_row * size;
out += start_row;
int i = 0;
for (; i + block.size() * N <= size; i += block.size() * N) {
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlockedVectorized<T, N>(
block.thread_rank(), in + k * size + i, vals[k]);
block.thread_rank(), in + k * size + r * (block.size() * N), vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
}
}
if (size > i) {
if (final_offset < size) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + i,
in + k * size + final_offset,
vals[k],
size,
__cast<T, U>(init));
for (int j = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
}
}
}
for (int i = 0; i < M; i++) {
accs[i] = cg::reduce(warp, accs[i], op);
}
if (warp.meta_group_size() > 1) {
__shared__ U shared_accumulators[32 * M];
if (warp.thread_rank() == 0) {
for (int i = 0; i < M; i++) {
shared_accumulators[warp.meta_group_rank() * M + i] = accs[i];
}
}
block.sync();
if (warp.thread_rank() < warp.meta_group_size()) {
for (int i = 0; i < M; i++) {
accs[i] = shared_accumulators[warp.thread_rank() * M + i];
}
} else {
for (int i = 0; i < M; i++) {
accs[i] = init;
}
}
for (int i = 0; i < M; i++) {
accs[i] = cg::reduce(warp, accs[i], op);
}
}
block_reduce(block, warp, accs, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
if (grid.block_rank() * M + M <= n_rows) {
@ -226,7 +227,7 @@ template <
typename U,
typename Op,
int NDIM,
int BLOCK_DIM_X,
int BLOCK_DIM,
int N_READS = 4>
__global__ void row_reduce_looped(
T* in,
@ -237,27 +238,28 @@ __global__ void row_reduce_looped(
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
if (out_idx >= out_size) {
return;
}
size_t out_idx = grid.block_rank();
Op op;
U total_val = ReduceInit<Op, T>::value();
U total[1];
U init = ReduceInit<Op, T>::value();
total[0] = init;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
size_t full_blocks = args.row_size / (BLOCK_DIM_X * N_READS);
size_t final_offset = full_blocks * BLOCK_DIM_X * N_READS;
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
for (size_t n = 0; n < args.non_row_reductions; n++) {
for (size_t r = 0; r < full_blocks; r++) {
T vals[N_READS];
cub::LoadDirectBlockedVectorized<T, N_READS>(
block.thread_rank(),
in + loop.location() + r * BLOCK_DIM_X * N_READS,
in + loop.location() + r * BLOCK_DIM * N_READS,
vals);
for (int i = 0; i < N_READS; i++) {
total_val = op(total_val, __cast<U, T>(vals[i]));
total[0] = op(total[0], __cast<U, T>(vals[i]));
}
}
if (final_offset < args.row_size) {
@ -267,26 +269,117 @@ __global__ void row_reduce_looped(
in + loop.location() + final_offset,
vals,
args.row_size - final_offset,
__cast<T, U>(ReduceInit<Op, T>::value()));
__cast<T, U>(init));
for (int i = 0; i < N_READS; i++) {
total_val = op(total_val, __cast<U, T>(vals[i]));
total[0] = op(total[0], __cast<U, T>(vals[i]));
}
}
// TODO: Maybe block.sync() here?
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp;
total_val = BlockReduceT(temp).Reduce(total_val, op);
__shared__ U shared_accumulators[32];
block_reduce(block, warp, total, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
out[out_idx] = total_val;
out[out_idx] = total[0];
}
}
template <typename T, typename U, typename Op, int N = 4>
__global__ void reduce_initialize(U* out, size_t out_size) {
auto grid = cg::this_grid();
if (grid.thread_rank() * N + N <= out_size) {
for (int i = 0; i < N; i++) {
out[grid.thread_rank() * N + i] = ReduceInit<Op, T>::value();
}
} else {
for (int i = grid.thread_rank() * N; i < out_size; i++) {
out[i] = ReduceInit<Op, T>::value();
}
}
}
template <typename T, typename U, typename Op, int BLOCK_DIM, int N_READS = 4>
__global__ void row_reduce_atomics(
T* in,
U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
size_t reduction_idx = grid.block_rank() / out_size;
size_t out_idx = grid.block_rank() % out_size;
Op op;
U total[1];
U init = ReduceInit<Op, T>::value();
total[0] = init;
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
in += elem_to_loc(
reduction_idx,
args.reduce_shape.data(),
args.reduce_strides.data(),
args.reduce_ndim);
for (size_t r = 0; r < full_blocks; r++) {
T vals[N_READS];
cub::LoadDirectBlockedVectorized<T, N_READS>(
block.thread_rank(), in + r * BLOCK_DIM * N_READS, vals);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], __cast<U, T>(vals[i]));
}
}
if (final_offset < args.row_size) {
T vals[N_READS];
cub::LoadDirectBlocked(
block.thread_rank(),
in + final_offset,
vals,
args.row_size - final_offset,
__cast<T, U>(init));
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], __cast<U, T>(vals[i]));
}
}
__shared__ U shared_accumulators[32];
block_reduce(block, warp, total, shared_accumulators, op, init);
if (block.thread_rank() == 0) {
op.atomic_update(out + out_idx, total[0]);
}
}
} // namespace cu
void reduce_initialize(
cu::CommandEncoder& encoder,
array& out,
Reduce::ReduceType reduce_type) {
constexpr int N_WRITES = 8;
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::reduce_initialize<T, U, OP, N_WRITES>;
auto [grid, block] =
get_launch_args(kernel, out, out.size() >= 1UL << 31, N_WRITES);
kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
});
});
});
}
void row_reduce_simple(
cu::CommandEncoder& encoder,
const array& in,
@ -296,23 +389,9 @@ void row_reduce_simple(
const ReductionPlan& plan) {
constexpr int N_READS = 8;
// Initialize out such that its strides match in's layout (except the fastest
// moving axis)
auto out_strides = in.strides();
for (auto& s : out_strides) {
s /= plan.shape.back();
}
auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
auto fl = in.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = data_size == out.size();
out.set_data(
allocator::malloc(out.nbytes()),
data_size,
out_strides,
fl,
allocator::free);
// Allocate data for the output using in's layout to avoid elem_to_loc in the
// kernel.
allocate_same_layout(out, in, axes);
// Just a way to get out of the constness because cub doesn't like it ...
// (sigh)
@ -356,31 +435,13 @@ void row_reduce_looped(
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan) {
const ReductionPlan& plan,
cu::RowReduceArgs args) {
constexpr int N_READS = 8;
// Initialize out such that it matches in's layout. Basically we keep any
// transpositions as it were and that allows us to skip finding the location
// of the output that matches the input.
auto out_strides = in.strides();
for (auto ax : axes) {
for (auto& s : out_strides) {
if (s > in.strides(ax)) {
s /= in.shape(ax);
}
}
}
auto [data_size, rc, cc] = check_contiguity(out.shape(), out_strides);
auto fl = in.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = data_size == out.size();
out.set_data(
allocator::malloc(out.nbytes()),
data_size,
out_strides,
fl,
allocator::free);
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
// Just a way to get out of the constness because cub doesn't like it ...
// (sigh)
@ -395,7 +456,7 @@ void row_reduce_looped(
using U = cu::ReduceResult<OP, T>::type;
// Calculate the grid and block dims
cu::RowReduceArgs args(in, plan, axes);
args.convert_shapes_to_contiguous(x, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = args.row_size / N_READS;
int threads = std::min(1024UL, reductions);
@ -419,6 +480,66 @@ void row_reduce_looped(
});
}
void row_reduce_atomics(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
cu::RowReduceArgs args) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
// Just a way to get out of the constness because cub doesn't like it ...
// (sigh)
array x = in;
// Initialize
reduce_initialize(encoder, out, reduce_type);
// Launch the reduction
encoder.set_input_array(x);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
args.convert_shapes_to_contiguous(x, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
if (grid.x * args.non_row_reductions < INT_MAX) {
grid.x *= args.non_row_reductions;
} else if (grid.y * args.non_row_reductions < 65536) {
grid.y *= args.non_row_reductions;
} else {
throw std::runtime_error(
"[row_reduce_atomics] Non-row reductions need to be factorized which is NYI");
}
size_t reductions = args.row_size / N_READS;
int threads = std::min(1024UL, reductions);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
dim3 block(threads, 1, 1);
// Pick the kernel
auto kernel = cu::row_reduce_atomics<T, U, OP, 32, N_READS>;
MLX_SWITCH_BLOCK_DIM(threads, THREADS, {
kernel = cu::row_reduce_atomics<T, U, OP, THREADS, N_READS>;
block.x = THREADS;
});
// Launch
kernel<<<grid, block, 0, stream>>>(
x.data<T>(), out.data<U>(), out.size(), args);
});
});
});
}
void row_reduce(
cu::CommandEncoder& encoder,
const array& in,
@ -430,10 +551,20 @@ void row_reduce(
// it has stride 1.
if (plan.shape.size() == 1) {
row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
return;
}
// Make the args struct to help route to the best kernel
cu::RowReduceArgs args(in, plan, axes);
// Let's use atomics to increase parallelism
if (false && args.row_size < 512) {
row_reduce_atomics(
encoder, in, out, reduce_type, axes, plan, std::move(args));
}
// Fallback row reduce
row_reduce_looped(encoder, in, out, reduce_type, axes, plan);
row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
// encoder.launch_kernel([&](cudaStream_t stream) {
// MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {