[CUDA] Fix reductions (#2314)

Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-28 21:21:21 +08:00)

This commit is contained in:
parent 2c11d10f8d
commit 772f471ff2
@@ -5,6 +5,7 @@ import os
 import time
 
 import torch
+import torch.cuda
 import torch.mps
 
 
@@ -44,8 +45,10 @@ def bench(f, *args):
 
 
 def sync_if_needed(x):
-    if x.device != torch.device("cpu"):
+    if x.device == torch.device("mps"):
         torch.mps.synchronize()
+    elif x.device == torch.device("cuda"):
+        torch.cuda.synchronize()
 
 
 @torch.no_grad()
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
     sync_if_needed(x)
 
 
+@torch.no_grad()
+def sum_and_add(axis, x, y):
+    z = x.sum(axis=axis, keepdims=True)
+    for i in range(50):
+        z = (z + y).sum(axis=axis, keepdims=True)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def softmax(axis, x):
     ys = []
@@ -340,7 +351,11 @@ if __name__ == "__main__":
         args.axis.pop(0)
 
     torch.set_num_threads(1)
-    device = "cpu" if args.cpu else "mps"
+    device = "mps"
+    if torch.cuda.is_available():
+        device = "cuda"
+    if args.cpu:
+        device = "cpu"
 
     types = args.dtype
     if not types:
@@ -460,5 +475,8 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))
 
+    elif args.benchmark == "sum_and_add":
+        print(bench(sum_and_add, axis, *xs))
+
     else:
         raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
@@ -5,11 +5,9 @@
 namespace mlx::core {
 
 std::pair<Shape, Strides> shapes_without_reduction_axes(
-    const array& x,
+    Shape shape,
+    Strides strides,
     const std::vector<int>& axes) {
-  auto shape = x.shape();
-  auto strides = x.strides();
-
   for (int i = axes.size() - 1; i >= 0; i--) {
     int a = axes[i];
     shape.erase(shape.begin() + a);
@@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
   return std::make_pair(shape, strides);
 }
 
+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    const array& x,
+    const std::vector<int>& axes) {
+  auto shape = x.shape();
+  auto strides = x.strides();
+  return shapes_without_reduction_axes(
+      std::move(shape), std::move(strides), axes);
+}
+
 ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
   // The data is all there and we are reducing over everything
   if (x.size() == x.data_size() && axes.size() == x.ndim() &&
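The refactor above splits shapes_without_reduction_axes into an overload that works directly on a shape/strides pair, so callers that have already built or modified those vectors can reuse it. Below is a minimal, self-contained sketch of the erase logic; it uses plain std::vector in place of MLX's Shape/Strides aliases (an assumption made only so the example compiles on its own).

// Illustrative sketch only, not the MLX implementation.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

std::pair<std::vector<int>, std::vector<int64_t>> drop_reduction_axes(
    std::vector<int> shape,
    std::vector<int64_t> strides,
    const std::vector<int>& axes) {
  // Erase from the back so earlier axis indices stay valid.
  for (int i = axes.size() - 1; i >= 0; i--) {
    int a = axes[i];
    shape.erase(shape.begin() + a);
    strides.erase(strides.begin() + a);
  }
  return {std::move(shape), std::move(strides)};
}

int main() {
  // A (4, 5, 6) row-contiguous array reduced over axis 1 keeps the
  // non-reduced extents with their original strides: shape (4, 6),
  // strides (30, 1).
  auto [shape, strides] = drop_reduction_axes({4, 5, 6}, {30, 6, 1}, {1});
  for (auto s : shape)
    std::cout << s << ' ';
  std::cout << '\n';
  for (auto s : strides)
    std::cout << s << ' ';
  std::cout << '\n';
}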
@@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
 std::pair<Shape, Strides> shapes_without_reduction_axes(
     const array& x,
     const std::vector<int>& axes);
+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    Shape shape,
+    Strides strides,
+    const std::vector<int>& axes);
 
 } // namespace mlx::core
@@ -29,9 +29,10 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
-  ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
@@ -157,7 +157,7 @@ void binary_op_gpu_inplace(
       if (ndim <= 3) {
         MLX_SWITCH_1_2_3(ndim, NDIM, {
           auto kernel =
-              &cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
+              cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
           auto [num_blocks, block_dims] =
               get_launch_args(kernel, out_a, large);
           kernel<<<num_blocks, block_dims, 0, stream>>>(
@@ -21,28 +21,11 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(!axes_.empty());
   assert(out.size() != in.size());
 
-  out.set_data(allocator::malloc(out.nbytes()));
-
   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
-  encoder.set_input_array(in);
-  encoder.set_output_array(out);
 
-  // Fill out with init value.
   if (in.size() == 0) {
-    encoder.launch_kernel([&](cudaStream_t stream) {
-      MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-        MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
-          using InType = cuda_type_t<CTYPE>;
-          using OutType = cu::ReduceResult<OP, InType>::type;
-          thrust::fill_n(
-              cu::thrust_policy(stream),
-              thrust::device_pointer_cast(out.data<OutType>()),
-              out.data_size(),
-              cu::ReduceInit<OP, InType>::value());
-        });
-      });
-    });
+    init_reduce(encoder, in, out, reduce_type_);
     return;
   }
 
@@ -51,7 +34,19 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // If it is a general reduce then copy the input to a contiguous array and
   // recompute the plan.
-  if (plan.type == GeneralReduce) {
+  //
+  // TODO: Instead of copying we can use elem-to-loc to deal with broadcasting
+  //       like we do in Metal. When it comes to broadcasted reduction axes
+  //       some can be ignored eg for min/max.
+  bool broadcasted = false;
+  for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) {
+    if (j < axes_.size() && axes_[j] == i) {
+      j++;
+    } else {
+      broadcasted = in.strides(i) == 0;
+    }
+  }
+  if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
     array in_copy(in.shape(), in.dtype(), nullptr, {});
     copy_gpu(in, in_copy, CopyType::General, s);
     encoder.add_temporary(in_copy);
@@ -59,9 +54,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     plan = get_reduction_plan(in, axes_);
   }
 
-  if ((plan.type == ContiguousAllReduce) ||
-      (plan.type == ContiguousReduce && plan.shape.size() == 1)) {
-    segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
+  if (plan.type == ContiguousAllReduce) {
+    all_reduce(encoder, in, out, reduce_type_);
     return;
   }
 
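The new check walks the input dimensions that are not being reduced and flags any with stride 0: a broadcasted axis that survives into the output means the input repeats data, so such cases now take the copy path together with GeneralReduce and non-contiguous inputs. A small host-side sketch of the same walk over plain shape/stride vectors (illustrative only, not the MLX API):

#include <cstdint>
#include <iostream>
#include <vector>

// Returns true if any axis that is *not* reduced has stride 0, i.e. the
// input was broadcast along an axis that survives into the output.
bool has_broadcasted_output_axis(
    const std::vector<int64_t>& strides,
    const std::vector<int>& axes) { // axes assumed sorted ascending
  bool broadcasted = false;
  for (size_t i = 0, j = 0; i < strides.size() && !broadcasted; i++) {
    if (j < axes.size() && axes[j] == static_cast<int>(i)) {
      j++; // this axis is reduced away; a zero stride here is harmless
    } else {
      broadcasted = strides[i] == 0;
    }
  }
  return broadcasted;
}

int main() {
  // Shape (8, 16, 32) where axis 0 was broadcast (stride 0), reducing axis 2:
  // the broadcast axis survives into the output, so the copy path is taken.
  std::cout << has_broadcasted_output_axis({0, 32, 1}, {2}) << '\n'; // 1
  // Reducing axis 0 instead: the zero-stride axis disappears, so this check
  // alone does not force a copy.
  std::cout << has_broadcasted_output_axis({0, 32, 1}, {0}) << '\n'; // 0
}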
mlx/backend/cuda/reduce/all_reduce.cu (new file, 150 lines)
@@ -0,0 +1,150 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/reduce/reduce.cuh"
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <cub/block/block_load.cuh>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename T, typename U, typename ReduceOp, int N = 4>
+__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
+  // TODO: Process multiple "rows" in each thread
+  constexpr int M = 1;
+
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  const U init = cu::ReduceInit<ReduceOp, T>::value();
+  ReduceOp op;
+
+  T vals[N];
+  U accs[M];
+  accs[0] = init;
+
+  size_t start = grid.block_rank() * block_step;
+  size_t end = start + block_step;
+  size_t check = min(end, size);
+
+  size_t i = start;
+  for (; i + block.size() * N <= check; i += block.size() * N) {
+    cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
+    for (int j = 0; j < N; j++) {
+      accs[0] = op(accs[0], __cast<U, T>(vals[j]));
+    }
+  }
+
+  if (i < check) {
+    cub::LoadDirectBlocked(
+        block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
+    for (int i = 0; i < N; i++) {
+      accs[0] = op(accs[0], __cast<U, T>(vals[i]));
+    }
+  }
+
+  __shared__ U shared_accumulators[32];
+  block_reduce(block, warp, accs, shared_accumulators, op, init);
+
+  if (block.thread_rank() == 0) {
+    out[grid.block_rank()] = accs[0];
+  }
+}
+
+} // namespace cu
+
+void all_reduce(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type) {
+  constexpr int N_READS = 8;
+
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto get_args = [](size_t size, int N) {
+    int threads = std::min(512UL, (size + N - 1) / N);
+    threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+    int reductions_per_step = threads * N;
+    size_t steps_needed =
+        (size + reductions_per_step - 1) / reductions_per_step;
+
+    int blocks;
+    if (steps_needed < 32) {
+      blocks = 1;
+    } else if (steps_needed < 128) {
+      blocks = 32;
+    } else if (steps_needed < 512) {
+      blocks = 128;
+    } else if (steps_needed < 1024) {
+      blocks = 512;
+    } else {
+      blocks = 1024;
+    }
+
+    size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
+    size_t block_step = steps_per_block * reductions_per_step;
+
+    return std::make_tuple(blocks, threads, block_step);
+  };
+
+  int blocks, threads;
+  size_t block_step;
+  size_t insize = in.size();
+  Dtype dt = in.dtype();
+
+  // Cub doesn't like const pointers for load (sigh).
+  void* indata = const_cast<void*>(in.data<void>());
+
+  // Large array so allocate an intermediate and accumulate there
+  std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
+  encoder.set_input_array(in);
+  if (blocks > 1) {
+    array intermediate({blocks}, out.dtype(), nullptr, {});
+    intermediate.set_data(allocator::malloc(intermediate.nbytes()));
+    encoder.add_temporary(intermediate);
+    encoder.set_output_array(intermediate);
+    encoder.launch_kernel([&](cudaStream_t stream) {
+      MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
+        MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+          using T = cuda_type_t<CTYPE>;
+          using U = cu::ReduceResult<OP, T>::type;
+          auto kernel = cu::all_reduce<T, U, OP, N_READS>;
+          kernel<<<blocks, threads, 0, stream>>>(
+              static_cast<T*>(indata),
+              intermediate.data<U>(),
+              block_step,
+              insize);
+        });
+      });
+    });
+
+    // Set the input for the next step and recalculate the blocks
+    indata = intermediate.data<void>();
+    dt = intermediate.dtype();
+    insize = intermediate.size();
+    std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
+    encoder.set_input_array(intermediate);
+  }
+
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using T = cuda_type_t<CTYPE>;
+        using U = cu::ReduceResult<OP, T>::type;
+        auto kernel = cu::all_reduce<T, U, OP, N_READS>;
+        kernel<<<blocks, threads, 0, stream>>>(
+            static_cast<T*>(indata), out.data<U>(), block_step, insize);
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
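all_reduce runs in at most two passes: get_args sizes a grid so that each block reduces a contiguous block_step slice into one partial result, and when more than one block is used a second launch reduces the partials. Below is a hedged host-only sketch of the same sizing heuristic with one worked example; WARP_SIZE is hard-coded to 32 here purely for the sketch, and the expected values in the comments were computed by hand.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <tuple>

// Host-side illustration of the get_args lambda above; not the device code.
std::tuple<int, int, size_t> get_args(size_t size, int N) {
  int threads = static_cast<int>(std::min<size_t>(512, (size + N - 1) / N));
  threads = ((threads + 31) / 32) * 32; // round up to a full warp of 32
  int reductions_per_step = threads * N;
  size_t steps_needed = (size + reductions_per_step - 1) / reductions_per_step;

  int blocks;
  if (steps_needed < 32) {
    blocks = 1;
  } else if (steps_needed < 128) {
    blocks = 32;
  } else if (steps_needed < 512) {
    blocks = 128;
  } else if (steps_needed < 1024) {
    blocks = 512;
  } else {
    blocks = 1024;
  }

  size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
  size_t block_step = steps_per_block * reductions_per_step;
  return {blocks, threads, block_step};
}

int main() {
  // For 2^24 elements and 8 reads per thread: threads = 512, each step covers
  // 4096 elements, 4096 steps are needed, so 1024 blocks each own a
  // block_step of 4 * 4096 = 16384 elements and write one partial. The second
  // launch then reduces the 1024 partials with a single block.
  auto [blocks, threads, block_step] = get_args(size_t(1) << 24, 8);
  std::cout << blocks << " " << threads << " " << block_step << '\n';
  // Expected: 1024 512 16384
}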
@@ -1,5 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
+#include <numeric>
+
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/reduce/reduce.cuh"
@@ -36,19 +38,36 @@ struct ColReduceArgs {
       const array& in,
       const ReductionPlan& plan,
       const std::vector<int>& axes) {
+    using ShapeVector = decltype(plan.shape);
+    using StridesVector = decltype(plan.strides);
+
+    ShapeVector shape_vec;
+    StridesVector strides_vec;
+
     assert(!plan.shape.empty());
     reduction_size = plan.shape.back();
     reduction_stride = plan.strides.back();
 
     int64_t stride_back = 1;
-    auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
+    std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes);
     while (!shape_vec.empty() && stride_back < reduction_stride) {
       stride_back *= shape_vec.back();
       shape_vec.pop_back();
       strides_vec.pop_back();
     }
+    std::vector<int> indices(shape_vec.size());
+    std::iota(indices.begin(), indices.end(), 0);
+    std::sort(indices.begin(), indices.end(), [&](int left, int right) {
+      return strides_vec[left] > strides_vec[right];
+    });
+    ShapeVector sorted_shape;
+    StridesVector sorted_strides;
+    for (auto idx : indices) {
+      sorted_shape.push_back(shape_vec[idx]);
+      sorted_strides.push_back(strides_vec[idx]);
+    }
     std::tie(shape_vec, strides_vec) =
-        collapse_contiguous_dims(shape_vec, strides_vec);
+        collapse_contiguous_dims(sorted_shape, sorted_strides);
     shape = const_param(shape_vec);
     strides = const_param(strides_vec);
     ndim = shape_vec.size();
@@ -64,86 +83,6 @@ struct ColReduceArgs {
   }
 };
 
-template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
-__global__ void col_reduce_small(
-    const T* in,
-    U* out,
-    const __grid_constant__ ColReduceArgs args) {
-  auto grid = cg::this_grid();
-  auto block = cg::this_thread_block();
-
-  int column =
-      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
-  if (column * N_READS >= args.reduction_stride) {
-    return;
-  }
-
-  int out_idx = grid.block_rank() / grid.dim_blocks().x;
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
-
-  Op op;
-  U totals[N_READS];
-  for (int i = 0; i < N_READS; i++) {
-    totals[i] = ReduceInit<Op, T>::value();
-  }
-
-  // Read input to local.
-  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-  loop.next(
-      block.thread_index().y,
-      args.reduce_shape.data(),
-      args.reduce_strides.data());
-  for (size_t r = block.thread_index().y;
-       r < args.non_col_reductions * args.reduction_size;
-       r += block.dim_threads().y) {
-    U vals[N_READS];
-    cub::LoadDirectBlocked(
-        column,
-        make_cast_iterator<U>(in + loop.location()),
-        vals,
-        args.reduction_stride,
-        ReduceInit<Op, T>::value());
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = op(vals[i], totals[i]);
-    }
-    loop.next(
-        block.dim_threads().y,
-        args.reduce_shape.data(),
-        args.reduce_strides.data());
-  }
-
-  // Do block reduce when each column has more than 1 element to reduce.
-  if (block.dim_threads().y > 1) {
-    __shared__ U shared_vals[32 * 8 * N_READS];
-    size_t col =
-        block.thread_index().y * block.dim_threads().x + block.thread_index().x;
-    for (int i = 0; i < N_READS; i++) {
-      shared_vals[col * N_READS + i] = totals[i];
-    }
-    block.sync();
-    if (block.thread_index().y == 0) {
-      for (int i = 0; i < N_READS; i++) {
-        totals[i] = shared_vals[block.thread_index().x * N_READS + i];
-      }
-      for (int j = 1; j < block.dim_threads().y; j++) {
-        col = j * block.dim_threads().x + block.thread_index().x;
-        for (int i = 0; i < N_READS; i++) {
-          totals[i] = op(shared_vals[col * N_READS + i], totals[i]);
-        }
-      }
-    }
-  }
-
-  // Write result.
-  if (block.thread_index().y == 0) {
-    cub::StoreDirectBlocked(
-        column,
-        out + out_idx * args.reduction_stride,
-        totals,
-        args.reduction_stride);
-  }
-}
-
 template <
     typename T,
     typename U,
@@ -152,67 +91,94 @@ template <
     int BM,
     int BN,
     int N_READS = 4>
-__global__ void col_reduce_looped(
-    const T* in,
-    U* out,
-    const __grid_constant__ ColReduceArgs args) {
+__global__ void
+col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);
 
-  constexpr int n_warps = BN / N_READS;
+  constexpr int threads_per_row = BN / N_READS;
 
-  int out_idx = grid.block_rank() / grid.dim_blocks().x;
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
+  // Compute the indices for the tile
+  size_t tile_idx = grid.block_rank();
+  size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
+  size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
 
+  // Compute the indices for the thread within the tile
+  short thread_x = block.thread_rank() % threads_per_row;
+  short thread_y = block.thread_rank() / threads_per_row;
+
+  // Move the input pointer
+  in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) +
+      tile_x * BN;
+
+  // Initialize the running totals
   Op op;
   U totals[N_READS];
   for (int i = 0; i < N_READS; i++) {
    totals[i] = ReduceInit<Op, T>::value();
   }
 
-  // Read input to local.
-  int r = block.thread_rank() / n_warps;
-  int column = block.thread_rank() % n_warps;
-  int in_offset = grid.block_index().x * BN;
   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-  loop.next(r, args.reduce_shape.data(), args.reduce_strides.data());
-  for (; r < args.non_col_reductions * args.reduction_size; r += BM) {
-    U vals[N_READS];
-    cub::LoadDirectBlocked(
-        column,
-        make_cast_iterator<U>(in + loop.location() + in_offset),
-        vals,
-        args.reduction_stride - in_offset,
-        ReduceInit<Op, T>::value());
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = op(vals[i], totals[i]);
+  loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
+  size_t total = args.non_col_reductions * args.reduction_size;
+  if (tile_x * BN + BN <= args.reduction_stride) {
+    if (args.reduction_stride % N_READS == 0) {
+      for (size_t r = thread_y; r < total; r += BM) {
+        T vals[N_READS];
+        cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
+        for (int i = 0; i < N_READS; i++) {
+          totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+        }
+        loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
+      }
+    } else {
+      for (size_t r = thread_y; r < total; r += BM) {
+        T vals[N_READS];
+        cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
+        for (int i = 0; i < N_READS; i++) {
+          totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+        }
+        loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
+      }
+    }
+  } else {
+    for (size_t r = thread_y; r < total; r += BM) {
+      T vals[N_READS];
+      cub::LoadDirectBlocked(
+          thread_x,
+          in + loop.location(),
+          vals,
+          args.reduction_stride - tile_x * BN,
+          __cast<T, U>(ReduceInit<Op, T>::value()));
+      for (int i = 0; i < N_READS; i++) {
+        totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+      }
+      loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
     }
-    loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
   }
 
   // Do warp reduce for each output.
-  constexpr int n_outputs = BN / n_warps;
+  constexpr int n_outputs = BN / threads_per_row;
   static_assert(BM == 32 && n_outputs == N_READS);
   __shared__ U shared_vals[BM * BN];
-  size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS;
+  short s_idx = thread_y * BN + thread_x * N_READS;
   for (int i = 0; i < N_READS; i++) {
-    shared_vals[col + i] = totals[i];
+    shared_vals[s_idx + i] = totals[i];
   }
   block.sync();
-  col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
+  s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs;
   for (int i = 0; i < n_outputs; i++) {
-    totals[i] = cg::reduce(warp, shared_vals[col + i], op);
+    totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op);
   }
 
   // Write result.
   if (warp.thread_rank() == 0) {
-    size_t out_offset = grid.block_index().x * BN;
     cub::StoreDirectBlocked(
         warp.meta_group_rank(),
-        out + out_idx * args.reduction_stride + out_offset,
+        out + tile_y * args.reduction_stride + tile_x * BN,
         totals,
-        args.reduction_stride - out_offset);
+        args.reduction_stride - tile_x * BN);
   }
 }
 
@@ -220,14 +186,55 @@ __global__ void col_reduce_looped(
 
 inline auto output_grid_for_col_reduce(
     const array& out,
-    const cu::ColReduceArgs& args) {
-  auto out_shape = out.shape();
-  auto out_strides = out.strides();
-  while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
-    out_shape.pop_back();
-    out_strides.pop_back();
+    const cu::ColReduceArgs& args,
+    int bn) {
+  int gx, gy = 1;
+  size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
+  size_t n_outer_blocks = out.size() / args.reduction_stride;
+  size_t n_blocks = n_outer_blocks * n_inner_blocks;
+  while (n_blocks / gy > INT32_MAX) {
+    gy *= 2;
   }
-  return get_2d_grid_dims(out_shape, out_strides);
+  gx = cuda::ceil_div(n_blocks, gy);
+
+  return dim3(gx, gy, 1);
+}
+
+void col_reduce_looped(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    cu::ColReduceArgs args) {
+  // Allocate data for the output using in's layout to access them as
+  // contiguously as possible.
+  allocate_same_layout(out, in, axes);
+
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
+          using T = cuda_type_t<CTYPE>;
+          using U = cu::ReduceResult<OP, T>::type;
+
+          // Cub doesn't like const pointers for vectorized loads. (sigh)
+          T* indata = const_cast<T*>(in.data<T>());
+
+          constexpr int N_READS = 4;
+          constexpr int BM = 32;
+          constexpr int BN = 32;
+          dim3 grid = output_grid_for_col_reduce(out, args, BN);
+          int blocks = BM * BN / N_READS;
+          auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
+          kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
+        });
+      });
+    });
+  });
 }
 
 void col_reduce(
@@ -237,42 +244,23 @@ void col_reduce(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan) {
+  // Current col reduce options
+  //
+  // - col_reduce_looped
+  //
+  //   It is a general strided reduce. Each threadblock computes the output for
+  //   a subrow of the fast moving axis. For instance 32 elements.
+  //
+  // Notes: As in row reduce we opt to read as much in order as possible and
+  //        leave transpositions as they are (contrary to our Metal backend).
+  //
+  //        Moreover we need different kernels for short rows and tuning
+
+  // Make the args struct to help route to the best kernel
   cu::ColReduceArgs args(in, plan, axes);
 
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      using InType = cuda_type_t<CTYPE>;
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using OutType = cu::ReduceResult<OP, InType>::type;
-        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
-          constexpr int N_READS = 4;
-          dim3 block_dims;
-          dim3 num_blocks = output_grid_for_col_reduce(out, args);
-          num_blocks.z = num_blocks.y;
-          num_blocks.y = num_blocks.x;
-          auto kernel =
-              cu::col_reduce_small<InType, OutType, OP, NDIM, N_READS>;
-          size_t total = args.non_col_reductions * args.reduction_size;
-          if (total < 32) {
-            size_t stride_blocks =
-                cuda::ceil_div(args.reduction_stride, N_READS);
-            block_dims.x = std::min(stride_blocks, 32ul);
-            block_dims.y = std::min(total, 8ul);
-            num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x);
-          } else {
-            constexpr int BM = 32;
-            constexpr int BN = 32;
-            block_dims.x = BM * BN / N_READS;
-            num_blocks.x = cuda::ceil_div(args.reduction_stride, BN);
-            kernel = cu::
-                col_reduce_looped<InType, OutType, OP, NDIM, BM, BN, N_READS>;
-          }
-          kernel<<<num_blocks, block_dims, 0, stream>>>(
-              in.data<InType>(), out.data<OutType>(), args);
-        });
-      });
-    });
-  });
+  // Fallback col reduce
+  col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
 }
 
 } // namespace mlx::core
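In the reworked col_reduce_looped each threadblock owns a BN-wide tile of output columns: the block's linear rank splits into tile_x (which BN-column slice of the reduction stride) and tile_y (which slice of the non-reduced outer dimensions). A small host-side sketch of that index split, assuming BN = 32 as in the launch code above:

#include <algorithm>
#include <cstddef>
#include <iostream>

// Map a block's linear rank to the output tile it computes, mirroring the
// tile_x / tile_y arithmetic in col_reduce_looped (illustration only).
void tile_for_block(size_t block_rank, size_t reduction_stride, size_t BN) {
  size_t tiles_per_row = (reduction_stride + BN - 1) / BN; // ceil division
  size_t tile_x = block_rank % tiles_per_row; // which BN-wide column slice
  size_t tile_y = block_rank / tiles_per_row; // which output row
  size_t first_col = tile_x * BN;
  std::cout << "block " << block_rank << " -> row " << tile_y << ", columns ["
            << first_col << ", "
            << std::min(first_col + BN, reduction_stride) << ")\n";
}

int main() {
  // reduction_stride = 100 needs ceil(100 / 32) = 4 tiles per output row; the
  // last, ragged tile is the case that takes the guarded (non-vectorized)
  // load path in the kernel.
  for (size_t b = 0; b < 8; b++) {
    tile_for_block(b, /*reduction_stride=*/100, /*BN=*/32);
  }
}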
mlx/backend/cuda/reduce/init_reduce.cu (new file, 50 lines)
@@ -0,0 +1,50 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/reduce/reduce.cuh"
+
+#include <cooperative_groups.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename T, typename U, typename Op>
+__global__ void init_reduce(U* out, size_t size) {
+  auto index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = ReduceInit<Op, T>::value();
+  }
+}
+
+} // namespace cu
+
+void init_reduce(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type) {
+  // Allocate if needed
+  if (out.data_shared_ptr() == nullptr) {
+    out.set_data(allocator::malloc(out.nbytes()));
+  }
+
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
+      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
+        using T = cuda_type_t<CTYPE>;
+        using U = cu::ReduceResult<OP, T>::type;
+        auto kernel = cu::init_reduce<T, U, OP>;
+        dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
+        dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
+        grid.x = (grid.x + 1023) / 1024;
+        kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
@@ -47,13 +47,11 @@ namespace mlx::core {
     throw std::invalid_argument("Unknown reduce type."); \
   }
 
-void segmented_reduce(
+void all_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
     array& out,
-    Reduce::ReduceType reduce_type,
-    const std::vector<int>& axes,
-    const ReductionPlan& plan);
+    Reduce::ReduceType reduce_type);
 
 void row_reduce(
     cu::CommandEncoder& encoder,
@@ -71,4 +69,10 @@ void col_reduce(
     const std::vector<int>& axes,
     const ReductionPlan& plan);
 
+void init_reduce(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type);
+
 } // namespace mlx::core
@@ -3,48 +3,89 @@
 #pragma once
 
 #include "mlx/backend/cuda/device/utils.cuh"
+#include "mlx/backend/cuda/reduce/reduce_utils.cuh"
 
 namespace mlx::core::cu {
 
 // Reduce ops.
 struct And {
-  __device__ bool operator()(bool a, bool b) {
+  __device__ __forceinline__ bool operator()(bool a, bool b) {
     return a && b;
   }
+
+  __device__ void atomic_update(bool* x, bool y) {
+    atomic_reduce<bool, And>(x, y);
+  }
 };
 
 struct Or {
-  __device__ bool operator()(bool a, bool b) {
+  __device__ __forceinline__ bool operator()(bool a, bool b) {
     return a || b;
   }
+
+  __device__ void atomic_update(bool* x, bool y) {
+    atomic_reduce<bool, Or>(x, y);
+  }
 };
 
 struct Sum {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a + b;
   }
+
+  template <typename T>
+  __device__ void atomic_update(T* x, T y) {
+    atomic_reduce<T, Sum>(x, y);
+  }
+
+  __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
+    atomicAdd(x, y);
+  }
+
+  __device__ void atomic_update(int* x, int y) {
+    atomicAdd(x, y);
+  }
+
+  __device__ void atomic_update(float* x, float y) {
+    atomicAdd(x, y);
+  }
 };
 
 struct Prod {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a * b;
   }
+
+  template <typename T>
+  __device__ void atomic_update(T* x, T y) {
+    atomic_reduce<T, Prod>(x, y);
+  }
 };
 
 struct Min {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a < b ? a : b;
   }
+
+  template <typename T>
+  __device__ void atomic_update(T* x, T y) {
+    atomic_reduce<T, Min>(x, y);
+  }
 };
 
 struct Max {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
     return a > b ? a : b;
   }
+
+  template <typename T>
+  __device__ void atomic_update(T* x, T y) {
+    atomic_reduce<T, Max>(x, y);
+  }
 };
 
 // Traits to get the result type of reduce op.
@@ -120,7 +161,7 @@ template <typename T>
 struct ReduceInit<Prod, T> {
   static constexpr __host__ __device__ auto value() {
     if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return T{1, 1};
+      return T{1, 0};
     } else {
       return typename ReduceResult<Prod, T>::type{1};
     }
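One of the fixes above is the Prod initializer for complex inputs: the multiplicative identity is 1 + 0i, so seeding the accumulator with {1, 1} multiplied every product by (1 + i), scaling it by sqrt(2) and rotating it by 45 degrees. A quick host-side check with std::complex (cuComplex multiplication on device follows the same arithmetic):

#include <complex>
#include <iostream>

int main() {
  std::complex<float> z(3.0f, -2.0f);
  std::complex<float> good(1.0f, 0.0f); // identity: leaves z unchanged
  std::complex<float> bad(1.0f, 1.0f);  // old init: scales and rotates z

  std::cout << good * z << '\n'; // (3,-2)
  std::cout << bad * z << '\n';  // (5,1)  == (1+i)(3-2i)
}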
mlx/backend/cuda/reduce/reduce_utils.cuh (new file, 158 lines)
@@ -0,0 +1,158 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <numeric>
+
+#include "mlx/backend/cuda/device/utils.cuh"
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <size_t N>
+struct uint_by_size;
+template <>
+struct uint_by_size<2> {
+  using type = uint16_t;
+};
+template <>
+struct uint_by_size<4> {
+  using type = uint32_t;
+};
+template <>
+struct uint_by_size<8> {
+  using type = unsigned long long int;
+};
+
+template <typename T, typename Op>
+__device__ void atomic_reduce(T* x, T y) {
+  if constexpr (sizeof(T) == 1) {
+    using U = uint16_t;
+    U* x_int = (U*)((char*)x - ((size_t)x % 2));
+    int shift = ((char*)x - (char*)x_int) * 8;
+    int mask = 0xff << shift;
+    U old_val, new_val;
+    do {
+      old_val = *x_int;
+      T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
+      new_val = (old_val & ~mask) | (result << shift);
+    } while (atomicCAS(x_int, old_val, new_val) != old_val);
+  } else {
+    using U = typename uint_by_size<sizeof(T)>::type;
+    U* x_int = (U*)(x);
+    U old_val, new_val;
+    do {
+      old_val = *x_int;
+      T result = Op{}(*((T*)&old_val), y);
+      new_val = *((U*)&result);
+    } while (atomicCAS(x_int, old_val, new_val) != old_val);
+  }
+}
+
+// TODO: Should make a custom complex type
+template <typename U, typename T>
+inline __device__ U __cast(T x) {
+  return static_cast<U>(x);
+}
+
+template <>
+inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
+  return x.x != 0 && x.y != 0;
+}
+
+template <>
+inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
+  return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
+}
+
+template <typename T, int N, typename Block, typename Warp, typename Op>
+inline __device__ void
+block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
+  // First reduce in the current warp
+  for (int i = 0; i < N; i++) {
+    vals[i] = cg::reduce(warp, vals[i], op);
+  }
+
+  // Reduce across warps
+  if (warp.meta_group_size() > 1) {
+    if (warp.thread_rank() == 0) {
+      for (int i = 0; i < N; i++) {
+        smem[warp.meta_group_rank() * N + i] = vals[i];
+      }
+    }
+    block.sync();
+    if (warp.thread_rank() < warp.meta_group_size()) {
+      for (int i = 0; i < N; i++) {
+        vals[i] = smem[warp.thread_rank() * N + i];
+      }
+    } else {
+      for (int i = 0; i < N; i++) {
+        vals[i] = init;
+      }
+    }
+    for (int i = 0; i < N; i++) {
+      vals[i] = cg::reduce(warp, vals[i], op);
+    }
+  }
+}
+
+} // namespace cu
+
+inline void allocate_same_layout(
+    array& out,
+    const array& in,
+    const std::vector<int>& axes) {
+  if (in.flags().row_contiguous) {
+    out.set_data(allocator::malloc(out.nbytes()));
+    return;
+  }
+
+  if (out.ndim() < in.ndim()) {
+    throw std::runtime_error(
+        "Reduction without keepdims only supported for row-contiguous inputs");
+  }
+
+  // Calculate the transpositions applied to in in order to apply them to out.
+  std::vector<int> axis_order(in.ndim());
+  std::iota(axis_order.begin(), axis_order.end(), 0);
+  std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) {
+    return in.strides(left) > in.strides(right);
+  });
+
+  // Transpose the shape and calculate the strides
+  Shape out_shape(in.ndim());
+  Strides out_strides(in.ndim(), 1);
+  for (int i = 0; i < in.ndim(); i++) {
+    out_shape[i] = out.shape(axis_order[i]);
+  }
+  for (int i = in.ndim() - 2; i >= 0; i--) {
+    out_strides[i] = out_shape[i + 1] * out_strides[i + 1];
+  }
+
+  // Reverse the axis order to get the final strides
+  Strides final_strides(in.ndim());
+  for (int i = 0; i < in.ndim(); i++) {
+    final_strides[axis_order[i]] = out_strides[i];
+  }
+
+  // Calculate the resulting contiguity and do the memory allocation
+  auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides);
+  auto fl = in.flags();
+  fl.row_contiguous = rc;
+  fl.col_contiguous = cc;
+  fl.contiguous = true;
+  out.set_data(
+      allocator::malloc(out.nbytes()),
+      data_size,
+      final_strides,
+      fl,
+      allocator::free);
+}
+
+} // namespace mlx::core
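atomic_reduce above emulates an atomic read-modify-write for operators with no hardware atomic (min, max, prod, and the sub-word types) by retrying an atomicCAS until the value it read has not changed underneath it. The same pattern can be shown on the host with std::atomic; this is a hedged analogue of the technique, not the device code itself.

#include <atomic>
#include <iostream>
#include <thread>
#include <vector>

// CAS loop: apply `op` atomically even though the primitive available is only
// compare-and-swap, mirroring the structure of cu::atomic_reduce.
template <typename T, typename Op>
void atomic_reduce(std::atomic<T>* x, T y, Op op) {
  T old_val = x->load();
  T new_val;
  do {
    new_val = op(old_val, y);
    // compare_exchange_weak reloads old_val on failure, so we simply retry
    // with the freshly observed value.
  } while (!x->compare_exchange_weak(old_val, new_val));
}

int main() {
  std::atomic<int> maximum{0};
  std::vector<std::thread> workers;
  for (int t = 0; t < 8; t++) {
    workers.emplace_back([&, t] {
      for (int i = 0; i < 1000; i++) {
        atomic_reduce(&maximum, t * 1000 + i, [](int a, int b) {
          return a > b ? a : b;
        });
      }
    });
  }
  for (auto& w : workers) {
    w.join();
  }
  std::cout << maximum.load() << '\n'; // 7999
}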
@ -1,5 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
@ -55,84 +57,108 @@ struct RowReduceArgs {
|
|||||||
non_row_reductions *= reduce_shape[i];
|
non_row_reductions *= reduce_shape[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert shape and strides as if in was contiguous
|
||||||
|
void sort_access_pattern(const array& in, const std::vector<int>& axes) {
|
||||||
|
auto shape_vec = in.shape();
|
||||||
|
auto strides_vec = in.strides();
|
||||||
|
std::tie(shape_vec, strides_vec) =
|
||||||
|
shapes_without_reduction_axes(shape_vec, strides_vec, axes);
|
||||||
|
std::vector<int> indices(shape_vec.size());
|
||||||
|
std::iota(indices.begin(), indices.end(), 0);
|
||||||
|
std::sort(indices.begin(), indices.end(), [&](int left, int right) {
|
||||||
|
return strides_vec[left] > strides_vec[right];
|
||||||
|
});
|
||||||
|
decltype(shape_vec) sorted_shape;
|
||||||
|
decltype(strides_vec) sorted_strides;
|
||||||
|
for (auto idx : indices) {
|
||||||
|
sorted_shape.push_back(shape_vec[idx]);
|
||||||
|
sorted_strides.push_back(strides_vec[idx]);
|
||||||
|
}
|
||||||
|
std::tie(shape_vec, strides_vec) =
|
||||||
|
collapse_contiguous_dims(sorted_shape, sorted_strides);
|
||||||
|
shape = const_param(shape_vec);
|
||||||
|
strides = const_param(strides_vec);
|
||||||
|
ndim = shape_vec.size();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
|
||||||
__global__ void row_reduce_small(
|
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||||
const T* in,
|
|
||||||
U* out,
|
|
||||||
size_t out_size,
|
|
||||||
const __grid_constant__ RowReduceArgs args) {
|
|
||||||
size_t out_idx = cg::this_grid().thread_rank();
|
|
||||||
if (out_idx >= out_size) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
Op op;
|
|
||||||
|
|
||||||
U total_val = ReduceInit<Op, T>::value();
|
|
||||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
|
||||||
|
|
||||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
|
||||||
|
|
||||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
|
||||||
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
|
||||||
U vals[N_READS];
|
|
||||||
cub::LoadDirectBlocked(
|
|
||||||
r,
|
|
||||||
make_cast_iterator<U>(in + loop.location()),
|
|
||||||
vals,
|
|
||||||
args.row_size,
|
|
||||||
ReduceInit<Op, T>::value());
|
|
||||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
|
||||||
}
|
|
||||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
|
||||||
}
|
|
||||||
|
|
||||||
out[out_idx] = total_val;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
|
||||||
__global__ void row_reduce_small_warp(
|
|
||||||
const T* in,
|
|
||||||
U* out,
|
|
||||||
size_t out_size,
|
|
||||||
const __grid_constant__ RowReduceArgs args) {
|
|
||||||
auto grid = cg::this_grid();
|
auto grid = cg::this_grid();
|
||||||
auto block = cg::this_thread_block();
|
auto block = cg::this_thread_block();
|
||||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
|
|
||||||
size_t out_idx = grid.thread_rank() / WARP_SIZE;
|
const U init = cu::ReduceInit<ReduceOp, T>::value();
|
||||||
if (out_idx >= out_size) {
|
ReduceOp op;
|
||||||
return;
|
|
||||||
|
T vals[M][N];
|
||||||
|
U accs[M];
|
||||||
|
for (int i = 0; i < M; i++) {
|
||||||
|
accs[i] = init;
|
||||||
}
|
}
|
||||||
|
|
||||||
Op op;
|
const size_t start_row =
|
||||||
|
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
|
||||||
|
const size_t full_blocks = size / (block.size() * N);
|
||||||
|
const size_t final_offset = full_blocks * (block.size() * N);
|
||||||
|
in += start_row * size;
|
||||||
|
out += start_row;
|
||||||
|
|
||||||
U total_val = ReduceInit<Op, T>::value();
|
if (size % N == 0) {
|
||||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
for (size_t r = 0; r < full_blocks; r++) {
|
||||||
|
for (int k = 0; k < M; k++) {
|
||||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
cub::LoadDirectBlockedVectorized<T, N>(
|
||||||
|
block.thread_rank(),
|
||||||
for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
|
in + k * size + r * (block.size() * N),
|
||||||
n += WARP_SIZE) {
|
vals[k]);
|
||||||
for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
|
for (int j = 0; j < N; j++) {
|
||||||
U vals[N_READS];
|
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
|
||||||
cub::LoadDirectBlocked(
|
}
|
||||||
r,
|
}
|
||||||
make_cast_iterator<U>(in + loop.location()),
|
}
|
||||||
vals,
|
} else {
|
||||||
args.row_size,
|
for (size_t r = 0; r < full_blocks; r++) {
|
||||||
ReduceInit<Op, T>::value());
|
for (int k = 0; k < M; k++) {
|
||||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
cub::LoadDirectBlocked(
|
||||||
|
block.thread_rank(),
|
||||||
|
in + k * size + r * (block.size() * N),
|
||||||
|
vals[k]);
|
||||||
|
for (int j = 0; j < N; j++) {
|
||||||
|
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
total_val = cg::reduce(warp, total_val, op);
|
if (final_offset < size) {
|
||||||
|
for (int k = 0; k < M; k++) {
|
||||||
|
cub::LoadDirectBlocked(
|
||||||
|
block.thread_rank(),
|
||||||
|
in + k * size + final_offset,
|
||||||
|
vals[k],
|
||||||
|
size,
|
||||||
|
__cast<T, U>(init));
|
||||||
|
for (int j = 0; j < N; j++) {
|
||||||
|
accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (warp.thread_rank() == 0) {
|
__shared__ U shared_accumulators[32 * M];
|
||||||
out[out_idx] = total_val;
|
block_reduce(block, warp, accs, shared_accumulators, op, init);
|
||||||
|
|
||||||
|
if (block.thread_rank() == 0) {
|
||||||
|
if (grid.block_rank() * M + M <= n_rows) {
|
||||||
|
for (int i = 0; i < M; i++) {
|
||||||
|
out[i] = accs[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
short offset = grid.block_rank() * M + M - n_rows;
|
||||||
|
for (int i = offset; i < M; i++) {
|
||||||
|
out[i] = accs[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -141,55 +167,165 @@ template <
|
|||||||
typename U,
|
typename U,
|
||||||
typename Op,
|
typename Op,
|
||||||
int NDIM,
|
int NDIM,
|
||||||
int BLOCK_DIM_X,
|
int BLOCK_DIM,
|
||||||
int N_READS = 4>
|
int N_READS = 4>
|
||||||
__global__ void row_reduce_looped(
|
__global__ void row_reduce_looped(
|
||||||
const T* in,
|
T* in,
|
||||||
U* out,
|
U* out,
|
||||||
size_t out_size,
|
size_t out_size,
|
||||||
const __grid_constant__ RowReduceArgs args) {
|
const __grid_constant__ RowReduceArgs args) {
|
||||||
auto grid = cg::this_grid();
|
auto grid = cg::this_grid();
|
||||||
auto block = cg::this_thread_block();
|
auto block = cg::this_thread_block();
|
||||||
|
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||||
|
|
||||||
size_t out_idx = grid.thread_rank() / BLOCK_DIM_X;
|
size_t out_idx = grid.block_rank();
|
||||||
if (out_idx >= out_size) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
Op op;
|
Op op;
|
||||||
|
|
||||||
U total_val = ReduceInit<Op, T>::value();
|
U total[1];
|
||||||
|
U init = ReduceInit<Op, T>::value();
|
||||||
|
total[0] = init;
|
||||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||||
|
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
|
||||||
|
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
|
||||||
|
|
||||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||||
|
|
||||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||||
for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS);
|
for (size_t r = 0; r < full_blocks; r++) {
|
||||||
r++) {
|
T vals[N_READS];
|
||||||
U vals[N_READS];
|
cub::LoadDirectBlockedVectorized<T, N_READS>(
|
||||||
cub::LoadDirectBlocked(
|
block.thread_rank(),
|
||||||
r * BLOCK_DIM_X + block.thread_index().x,
|
in + loop.location() + r * BLOCK_DIM * N_READS,
|
||||||
make_cast_iterator<U>(in + loop.location()),
|
vals);
|
||||||
vals,
|
for (int i = 0; i < N_READS; i++) {
|
||||||
args.row_size,
|
total[0] = op(total[0], __cast<U, T>(vals[i]));
|
||||||
ReduceInit<Op, T>::value());
|
}
|
||||||
total_val = op(total_val, cub::ThreadReduce(vals, op));
|
|
||||||
}
|
}
|
||||||
|
if (final_offset < args.row_size) {
|
||||||
|
T vals[N_READS];
|
||||||
|
cub::LoadDirectBlocked(
|
||||||
|
block.thread_rank(),
|
||||||
|
in + loop.location() + final_offset,
|
||||||
|
vals,
|
||||||
|
args.row_size - final_offset,
|
||||||
|
__cast<T, U>(init));
|
||||||
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
total[0] = op(total[0], __cast<U, T>(vals[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: Maybe block.sync() here?
|
||||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef cub::BlockReduce<U, BLOCK_DIM_X> BlockReduceT;
|
__shared__ U shared_accumulators[32];
|
||||||
__shared__ typename BlockReduceT::TempStorage temp;
|
block_reduce(block, warp, total, shared_accumulators, op, init);
|
||||||
|
|
||||||
total_val = BlockReduceT(temp).Reduce(total_val, op);
|
|
||||||
|
|
||||||
if (block.thread_rank() == 0) {
|
if (block.thread_rank() == 0) {
|
||||||
out[out_idx] = total_val;
|
out[out_idx] = total[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace cu
|
||||||
|
|
||||||
|
void row_reduce_simple(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
Reduce::ReduceType reduce_type,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const ReductionPlan& plan) {
|
||||||
|
constexpr int N_READS = 8;
|
||||||
|
|
||||||
|
// Allocate data for the output using in's layout to avoid elem_to_loc in the
|
||||||
|
// kernel.
|
||||||
|
allocate_same_layout(out, in, axes);
|
||||||
|
|
||||||
|
// TODO: If out.size() < 1024 which will be a common case then write this in
|
||||||
|
// 2 passes. Something like 32 * out.size() and then do a warp reduce.
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
|
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
|
||||||
|
using T = cuda_type_t<CTYPE>;
|
||||||
|
using U = cu::ReduceResult<OP, T>::type;
|
||||||
|
|
||||||
|
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||||
|
T* indata = const_cast<T*>(in.data<T>());
|
||||||
|
|
||||||
|
// Calculate the grid and block dims
|
||||||
|
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
|
||||||
|
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
|
int threads = std::min(1024UL, reductions);
|
||||||
|
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||||
|
dim3 block(threads, 1, 1);
|
||||||
|
|
||||||
|
// Pick the kernel
|
||||||
|
auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
|
||||||
|
if (grid.x >= 1024) {
|
||||||
|
grid.x = (grid.x + 1) / 2;
|
||||||
|
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launch
|
||||||
|
kernel<<<grid, block, 0, stream>>>(
|
||||||
|
indata, out.data<U>(), out.size(), plan.shape.back());
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void row_reduce_looped(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan,
    cu::RowReduceArgs args) {
  constexpr int N_READS = 8;

  // Allocate data for the output using in's layout to access them as
  // contiguously as possible.
  allocate_same_layout(out, in, axes);

  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
        using T = cuda_type_t<CTYPE>;
        using U = cu::ReduceResult<OP, T>::type;

        // Cub doesn't like const pointers for vectorized loads. (sigh)
        T* indata = const_cast<T*>(in.data<T>());

        // Calculate the grid and block dims
        args.sort_access_pattern(in, axes);
        dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
        size_t reductions = (args.row_size + N_READS - 1) / N_READS;
        int threads = std::min(1024UL, reductions);
        threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
        dim3 block(threads, 1, 1);

        // Pick the kernel
        auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
          MLX_SWITCH_BLOCK_DIM(threads, THREADS, {
            kernel = cu::row_reduce_looped<T, U, OP, NDIM, THREADS, N_READS>;
            block.x = THREADS;
          });
        });

        // Launch
        kernel<<<grid, block, 0, stream>>>(
            indata, out.data<U>(), out.size(), args);
      });
    });
  });
}
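MLX_SWITCH_REDUCE_NDIM and MLX_SWITCH_BLOCK_DIM above turn runtime values into compile-time template arguments so the kernel can be instantiated with a fixed loop nest and block width. A generic sketch of that dispatch idiom follows; the function name and the particular case list are assumptions for illustration, not the macros' actual definitions.

#include <type_traits>

// Map a runtime thread count onto a compile-time constant that a kernel
// template can consume. Case list here is an assumption for illustration.
template <typename F>
void switch_block_dim(int threads, F&& f) {
  switch (threads) {
    case 32:
      f(std::integral_constant<int, 32>{});
      break;
    case 64:
      f(std::integral_constant<int, 64>{});
      break;
    case 128:
      f(std::integral_constant<int, 128>{});
      break;
    case 256:
      f(std::integral_constant<int, 256>{});
      break;
    case 512:
      f(std::integral_constant<int, 512>{});
      break;
    default:
      f(std::integral_constant<int, 1024>{});
      break;
  }
}

// Usage sketch:
//   switch_block_dim(threads, [&](auto bd) {
//     constexpr int BLOCK = decltype(bd)::value;
//     // ...instantiate a kernel templated on BLOCK...
//   });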
@@ -197,54 +333,35 @@ void row_reduce(
void row_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type,
    const std::vector<int>& axes,
    const ReductionPlan& plan) {
+  // Current row reduction options
+  //
+  // - row_reduce_simple
+  //
+  //   That means that we are simply reducing across the fastest moving axis.
+  //   We are reducing 1 or 2 rows per threadblock depending on the size of
+  //   output.
+  //
+  // - row_reduce_looped
+  //
+  //   It is a general row reduction. We are computing 1 output per
+  //   threadblock. We read the fastest moving axis vectorized and loop over
+  //   the rest of the axes.
+  //
+  // Notes: We opt to read as much in order as possible and leave
+  //        transpositions as they are (contrary to our Metal backend).
+
+  // Simple row reduce means that we have 1 axis that we are reducing over and
+  // it has stride 1.
+  if (plan.shape.size() == 1) {
+    row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
+    return;
+  }
+
+  // Make the args struct to help route to the best kernel
  cu::RowReduceArgs args(in, plan, axes);

-  encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      using InType = cuda_type_t<CTYPE>;
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using OutType = cu::ReduceResult<OP, InType>::type;
-        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
-          constexpr size_t N_READS = 4;
-          dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
-          dim3 block_dims, num_blocks;
-          auto kernel =
-              cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
-          if (args.row_size <= 64) {
-            if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
-                (args.non_row_reductions <= 8)) {
-              block_dims.x = std::min(out_dims.x, 1024u);
-              num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
-              num_blocks.y = out_dims.y;
-            } else {
-              block_dims.x = WARP_SIZE;
-              num_blocks.y = out_dims.x;
-              num_blocks.z = out_dims.y;
-              kernel =
-                  cu::row_reduce_small_warp<InType, OutType, OP, NDIM, N_READS>;
-            }
-          } else {
-            size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
-            num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
-            MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
-              num_blocks.y = out_dims.x;
-              num_blocks.z = out_dims.y;
-              block_dims.x = BLOCK_DIM_X;
-              kernel = cu::row_reduce_looped<
-                  InType,
-                  OutType,
-                  OP,
-                  NDIM,
-                  BLOCK_DIM_X,
-                  N_READS>;
-            });
-          }
-          kernel<<<num_blocks, block_dims, 0, stream>>>(
-              in.data<InType>(), out.data<OutType>(), out.size(), args);
-        });
-      });
-    });
-  });
+  // Fallback row reduce
+  row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
}

} // namespace mlx::core
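Both row kernels in this file handle a ragged tail with cub::LoadDirectBlocked, which fills out-of-range slots with a default value so the partial last tile folds into the accumulator without a scalar cleanup loop; the diff passes the reduction's init value, which is the identity for the op. A minimal device-side sketch of that guarded tail load (the helper name and template parameters are assumptions, not code from the diff):

#include <cub/block/block_load.cuh>

// Reduce the ragged tail of a row of `row_size` elements. Slots past the end
// are filled with `init`, so they do not change the accumulator.
template <typename T, typename Op, int BLOCK, int N_READS>
__device__ T reduce_row_tail(const T* row, int row_size, T init, Op op) {
  T acc = init;
  int full = (row_size / (BLOCK * N_READS)) * (BLOCK * N_READS);
  if (full < row_size) {
    T vals[N_READS];
    cub::LoadDirectBlocked(threadIdx.x, row + full, vals, row_size - full, init);
    for (int i = 0; i < N_READS; i++) {
      acc = op(acc, vals[i]);
    }
  }
  return acc;
}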
@@ -1,84 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/cast_op.cuh"
-#include "mlx/backend/cuda/reduce/reduce.cuh"
-
-#include <thrust/device_ptr.h>
-#include <cub/device/device_reduce.cuh>
-#include <cub/device/device_segmented_reduce.cuh>
-
-namespace mlx::core {
-
-template <typename... Args>
-void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) {
-  // Allocate temporary storage.
-  size_t size;
-  CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...));
-  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
-  encoder.add_temporary(temp);
-  // Run op.
-  CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data<void>(), size, args...));
-}
-
-template <typename... Args>
-void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) {
-  // Allocate temporary storage.
-  size_t size;
-  CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...));
-  array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
-  encoder.add_temporary(temp);
-  // Run op.
-  CHECK_CUDA_ERROR(
-      cub::DeviceSegmentedReduce::Reduce(temp.data<void>(), size, args...));
-}
-
-struct MultiplyOp {
-  int factor;
-  __device__ int operator()(int i) {
-    return i * factor;
-  }
-};
-
-void segmented_reduce(
-    cu::CommandEncoder& encoder,
-    const array& in,
-    array& out,
-    Reduce::ReduceType reduce_type,
-    const std::vector<int>& axes,
-    const ReductionPlan& plan) {
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using InType = cuda_type_t<CTYPE>;
-        using OutType = cu::ReduceResult<OP, InType>::type;
-        auto in_iter = cu::make_cast_iterator<OutType>(
-            thrust::device_pointer_cast(in.data<InType>()));
-        auto out_ptr = thrust::device_pointer_cast(out.data<OutType>());
-        auto init = cu::ReduceInit<OP, InType>::value();
-
-        if (plan.type == ContiguousAllReduce) {
-          cub_all_reduce(
-              encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream);
-        } else if (plan.type == ContiguousReduce) {
-          auto offsets = thrust::make_transform_iterator(
-              thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()});
-          cub_segmented_reduce(
-              encoder,
-              in_iter,
-              out_ptr,
-              out.size(),
-              offsets,
-              offsets + 1,
-              OP(),
-              init,
-              stream);
-        } else {
-          throw std::runtime_error("Unsupported plan in segmented_reduce.");
-        }
-      });
-    });
-  });
-}
-
-} // namespace mlx::core
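The deleted helpers above wrapped CUB's two-phase calling convention: call the algorithm once with a null workspace so it only reports the temporary-storage size, allocate that much, then call it again to actually run. A standalone sketch of that pattern with cub::DeviceReduce (the buffer names, the float/Max choice, and the use of cudaMallocAsync are illustrative assumptions; error handling is omitted):

#include <cmath>
#include <cuda_runtime.h>
#include <cub/device/device_reduce.cuh>

// Two-phase CUB reduction: size query, allocate, run.
void cub_max_reduce_sketch(
    const float* d_in, float* d_out, int n, cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Phase 1: nullptr workspace -> CUB only writes the required size.
  cub::DeviceReduce::Reduce(
      nullptr, temp_bytes, d_in, d_out, n, cub::Max(), -INFINITY, stream);
  void* d_temp = nullptr;
  cudaMallocAsync(&d_temp, temp_bytes, stream);
  // Phase 2: the same call with a real workspace performs the reduction.
  cub::DeviceReduce::Reduce(
      d_temp, temp_bytes, d_in, d_out, n, cub::Max(), -INFINITY, stream);
  cudaFreeAsync(d_temp, stream);
}

The deleted segmented path built its segment offsets on the fly from a counting iterator and a stride multiplier, which is why no explicit offsets array was ever materialized.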
@@ -51,7 +51,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
        make_cast_iterator<AccT>(in),
        vals,
        axis_size,
-        Limits<AccT>::finite_min());
+        Limits<AccT>::min());
    prevmax = maxval;
    maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
    // Online normalizer calculation for softmax:
@@ -79,7 +79,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
  block.sync();
  maxval = warp.thread_rank() < warp.meta_group_size()
      ? local_max[warp.thread_rank()]
-      : Limits<AccT>::finite_min();
+      : Limits<AccT>::min();
  maxval = cg::reduce(warp, maxval, max_op);
  normalizer = normalizer * softmax_exp(prevmax - maxval);
  if (warp.thread_rank() == 0) {
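The "online normalizer" the softmax kernel refers to is the streaming formulation: a single pass keeps a running max m and a running sum l of exp(x - m), rescaling l by exp(m_old - m_new) whenever the max grows. A scalar CPU sketch of the recurrence (illustrative only, not the kernel's code):

#include <algorithm>
#include <cmath>
#include <vector>

// Streaming softmax: after one pass, softmax(x_i) = exp(x_i - m) / l.
void online_softmax(std::vector<float>& x) {
  float m = -INFINITY; // running max
  float l = 0.0f;      // running sum of exp(x - m)
  for (float v : x) {
    float m_new = std::max(m, v);
    l = l * std::exp(m - m_new) + std::exp(v - m_new);
    m = m_new;
  }
  for (float& v : x) {
    v = std::exp(v - m) / l;
  }
}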
@@ -1,7 +1,6 @@
cuda_skip = {
    "TestArray.test_api",
    "TestBF16.test_arg_reduction_ops",
-    "TestBF16.test_reduction_ops",
    "TestBlas.test_complex_gemm",
    "TestEinsum.test_ellipses",
    "TestEinsum.test_opt_einsum_test_cases",
@@ -13,11 +12,7 @@ cuda_skip = {
    "TestLayers.test_upsample",
    "TestOps.test_complex_ops",
    "TestOps.test_dynamic_slicing",
-    "TestOps.test_softmax",
-    "TestReduce.test_axis_permutation_sums",
    "TestReduce.test_dtypes",
-    "TestReduce.test_expand_sums",
-    "TestReduce.test_many_reduction_axes",
    "TestUpsample.test_torch_upsample",
    # Block masked matmul NYI
    "TestBlas.test_block_masked_matmul",