Mirror of https://github.com/ml-explore/mlx.git
@@ -96,43 +96,52 @@ void all_reduce(
   int blocks, threads;
   size_t block_step;
-  array x = in;
+  size_t insize = in.size();
+  Dtype dt = in.dtype();
+
+  // Cub doesn't like const pointers for load (sigh).
+  void* indata = const_cast<void*>(in.data<void>());
 
   // Large array so allocate an intermediate and accumulate there
-  std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
+  std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
+  encoder.set_input_array(in);
   if (blocks > 1) {
     array intermediate({blocks}, out.dtype(), nullptr, {});
     intermediate.set_data(allocator::malloc(intermediate.nbytes()));
     encoder.add_temporary(intermediate);
-    encoder.set_input_array(x);
     encoder.set_output_array(intermediate);
     encoder.launch_kernel([&](cudaStream_t stream) {
-      MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+      MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
         MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
           using T = cuda_type_t<CTYPE>;
           using U = cu::ReduceResult<OP, T>::type;
           auto kernel = cu::all_reduce<T, U, OP, N_READS>;
           kernel<<<blocks, threads, 0, stream>>>(
-              x.data<T>(), intermediate.data<U>(), block_step, x.size());
+              static_cast<T*>(indata),
+              intermediate.data<U>(),
+              block_step,
+              insize);
         });
       });
     });
 
     // Set the input for the next step and recalculate the blocks
-    x = intermediate;
-    std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
+    indata = intermediate.data<void>();
+    dt = intermediate.dtype();
+    insize = intermediate.size();
+    std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
+    encoder.set_input_array(intermediate);
   }
 
-  encoder.set_input_array(x);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+    MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
       MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
         using T = cuda_type_t<CTYPE>;
         using U = cu::ReduceResult<OP, T>::type;
         auto kernel = cu::all_reduce<T, U, OP, N_READS>;
         kernel<<<blocks, threads, 0, stream>>>(
-            x.data<T>(), out.data<U>(), block_step, x.size());
+            static_cast<T*>(indata), out.data<U>(), block_step, insize);
       });
     });
   });
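Aside: the hunk above keeps the existing two-pass strategy and only changes how the input is threaded through it (a raw dtype/size/pointer triple instead of copying an array handle). The sketch below is a minimal, host-side C++ illustration of that strategy, not MLX code, and the block sizes are made up: when the array is large, a first pass writes one partial per block into an intermediate buffer, and that intermediate then becomes the input of a second, much smaller pass.

#include <algorithm>
#include <cstdio>
#include <vector>

// Reduce `n` values from `in` into `blocks` partials, `step` elements per block.
static void reduce_pass(const float* in, size_t n, size_t step,
                        size_t blocks, float* out) {
  for (size_t b = 0; b < blocks; ++b) {
    float acc = 0.0f;
    for (size_t i = b * step; i < std::min(n, (b + 1) * step); ++i)
      acc += in[i];
    out[b] = acc;
  }
}

int main() {
  std::vector<float> x(1 << 20, 1.0f);
  size_t step = 4096, blocks = (x.size() + step - 1) / step;

  // Pass 1: large input -> one partial per block in an intermediate buffer.
  std::vector<float> intermediate(blocks);
  reduce_pass(x.data(), x.size(), step, blocks, intermediate.data());

  // Pass 2: the intermediate becomes the new (tiny) input; one block suffices.
  float total = 0.0f;
  reduce_pass(intermediate.data(), blocks, blocks, 1, &total);
  std::printf("sum = %g (expected %zu)\n", total, x.size());
}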
@@ -38,12 +38,18 @@ struct ColReduceArgs {
       const array& in,
       const ReductionPlan& plan,
       const std::vector<int>& axes) {
+    using ShapeVector = decltype(plan.shape);
+    using StridesVector = decltype(plan.strides);
+
+    ShapeVector shape_vec;
+    StridesVector strides_vec;
+
     assert(!plan.shape.empty());
     reduction_size = plan.shape.back();
     reduction_stride = plan.strides.back();
 
     int64_t stride_back = 1;
-    auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
+    std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes);
     while (!shape_vec.empty() && stride_back < reduction_stride) {
       stride_back *= shape_vec.back();
       shape_vec.pop_back();
@@ -54,8 +60,8 @@ struct ColReduceArgs {
     std::sort(indices.begin(), indices.end(), [&](int left, int right) {
       return strides_vec[left] > strides_vec[right];
     });
-    decltype(shape_vec) sorted_shape;
-    decltype(strides_vec) sorted_strides;
+    ShapeVector sorted_shape;
+    StridesVector sorted_strides;
     for (auto idx : indices) {
       sorted_shape.push_back(shape_vec[idx]);
       sorted_strides.push_back(strides_vec[idx]);
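Aside: the std::tie change above is needed because a structured binding always introduces new names, so it cannot assign into shape_vec/strides_vec once those are declared up front via the new ShapeVector/StridesVector aliases. A minimal, self-contained illustration of the difference (the helper and vector types are stand-ins, not MLX code):

#include <tuple>
#include <vector>

// Stand-in for shapes_without_reduction_axes(...).
std::tuple<std::vector<int>, std::vector<long>> make_shape_strides() {
  return std::make_tuple(std::vector<int>{2, 3}, std::vector<long>{3L, 1L});
}

int main() {
  using ShapeVector = std::vector<int>;     // stand-in for decltype(plan.shape)
  using StridesVector = std::vector<long>;  // stand-in for decltype(plan.strides)

  ShapeVector shape_vec;
  StridesVector strides_vec;

  // auto [shape_vec, strides_vec] = make_shape_strides();  // error: redeclares the names
  std::tie(shape_vec, strides_vec) = make_shape_strides();   // assigns into existing variables

  return static_cast<int>(shape_vec.size() + strides_vec.size());
}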
@@ -206,26 +212,25 @@ void col_reduce_looped(
   // contiguously as possible.
   allocate_same_layout(out, in, axes);
 
-  // Just a way to get out of the constness because cub doesn't like it ...
-  // (sigh)
-  array x = in;
-
-  encoder.set_input_array(x);
+  encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
       MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
         MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
           using T = cuda_type_t<CTYPE>;
           using U = cu::ReduceResult<OP, T>::type;
 
+          // Cub doesn't like const pointers for vectorized loads. (sigh)
+          T* indata = const_cast<T*>(in.data<T>());
+
           constexpr int N_READS = 4;
           constexpr int BM = 32;
           constexpr int BN = 32;
           dim3 grid = output_grid_for_col_reduce(out, args, BN);
           int blocks = BM * BN / N_READS;
           auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
-          kernel<<<grid, blocks, 0, stream>>>(x.data<T>(), out.data<U>(), args);
+          kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
         });
       });
     });
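Aside: the "Cub doesn't like const pointers" comments refer to cub-style vectorized loads wanting a mutable T* even though the data is only read, which is why the launchers above const_cast the input pointer inside the type switch. A toy CUDA sketch of the same pattern, with an illustrative copy kernel standing in for cu::col_reduce_looped (the kernel never writes through the casted pointer):

#include <cstdio>
#include <cuda_runtime.h>

template <typename T>
__global__ void copy_kernel(T* in, T* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];  // `in` is only read despite the mutable type
}

template <typename T>
void launch_copy(const T* in, T* out, int n, cudaStream_t stream) {
  // Drop constness for the kernel's load path, cub-style. (sigh)
  T* indata = const_cast<T*>(in);
  copy_kernel<T><<<(n + 255) / 256, 256, 0, stream>>>(indata, out, n);
}

int main() {
  const int n = 1024;
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  launch_copy<float>(in, out, n, /*stream=*/0);
  cudaDeviceSynchronize();
  std::printf("status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(in);
  cudaFree(out);
}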
@@ -247,8 +252,7 @@ void col_reduce(
   // a subrow of the fast moving axis. For instance 32 elements.
   //
   // Notes: As in row reduce we opt to read as much in order as possible and
-  // leave
-  // transpositions as they are (contrary to our Metal backend).
+  // leave transpositions as they are (contrary to our Metal backend).
   //
   // Moreover we need different kernels for short rows and tuning
@@ -31,7 +31,6 @@ void init_reduce(
     out.set_data(allocator::malloc(out.nbytes()));
   }
 
-  encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
     MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
@@ -146,7 +146,7 @@ inline void allocate_same_layout(
   auto fl = in.flags();
   fl.row_contiguous = rc;
   fl.col_contiguous = cc;
-  fl.contiguous = data_size == out.size();
+  fl.contiguous = true;
   out.set_data(
       allocator::malloc(out.nbytes()),
       data_size,
@@ -241,20 +241,19 @@ void row_reduce_simple(
   // kernel.
   allocate_same_layout(out, in, axes);
 
-  // Just a way to get out of the constness because cub doesn't like it ...
-  // (sigh)
-  array x = in;
-
   // TODO: If out.size() < 1024 which will be a common case then write this in
   // 2 passes. Something like 32 * out.size() and then do a warp reduce.
-  encoder.set_input_array(x);
+  encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
       MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
         using T = cuda_type_t<CTYPE>;
         using U = cu::ReduceResult<OP, T>::type;
 
+        // Cub doesn't like const pointers for vectorized loads. (sigh)
+        T* indata = const_cast<T*>(in.data<T>());
+
         // Calculate the grid and block dims
         size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
         dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
@@ -271,7 +270,7 @@ void row_reduce_simple(
 
         // Launch
         kernel<<<grid, block, 0, stream>>>(
-            x.data<T>(), out.data<U>(), out.size(), plan.shape.back());
+            indata, out.data<U>(), out.size(), plan.shape.back());
       });
     });
   });
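Aside: the TODO in row_reduce_simple above suggests finishing small outputs in two passes that end in a warp reduce over roughly 32 partials per output. A generic CUDA sketch of such a warp-level sum, illustrative only and not the MLX kernel (it assumes at most 32 partials per output row):

#include <cstdio>
#include <cuda_runtime.h>

__device__ float warp_sum(float v) {
  // Offset reduction within one warp: after the loop, lane 0 holds the total.
  for (int offset = 16; offset > 0; offset >>= 1)
    v += __shfl_down_sync(0xffffffff, v, offset);
  return v;
}

__global__ void warp_reduce_rows(const float* partials, float* out, int per_row) {
  // One 32-thread block per output; each lane loads at most one partial.
  int lane = threadIdx.x & 31;
  int row = blockIdx.x;
  float v = lane < per_row ? partials[row * per_row + lane] : 0.0f;
  v = warp_sum(v);
  if (lane == 0) out[row] = v;
}

int main() {
  const int rows = 8, per_row = 32;
  float *partials = nullptr, *out = nullptr;
  cudaMalloc(&partials, rows * per_row * sizeof(float));
  cudaMalloc(&out, rows * sizeof(float));
  cudaMemset(partials, 0, rows * per_row * sizeof(float));  // deterministic input
  warp_reduce_rows<<<rows, 32>>>(partials, out, per_row);
  cudaDeviceSynchronize();
  std::printf("status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(partials);
  cudaFree(out);
}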
@@ -291,20 +290,19 @@ void row_reduce_looped(
   // contiguously as possible.
   allocate_same_layout(out, in, axes);
 
-  // Just a way to get out of the constness because cub doesn't like it ...
-  // (sigh)
-  array x = in;
-
-  encoder.set_input_array(x);
+  encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
+    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
       MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
         using T = cuda_type_t<CTYPE>;
         using U = cu::ReduceResult<OP, T>::type;
 
+        // Cub doesn't like const pointers for vectorized loads. (sigh)
+        T* indata = const_cast<T*>(in.data<T>());
+
         // Calculate the grid and block dims
-        args.sort_access_pattern(x, axes);
+        args.sort_access_pattern(in, axes);
         dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
         size_t reductions = (args.row_size + N_READS - 1) / N_READS;
         int threads = std::min(1024UL, reductions);
@@ -322,7 +320,7 @@ void row_reduce_looped(
 
         // Launch
         kernel<<<grid, block, 0, stream>>>(
-            x.data<T>(), out.data<U>(), out.size(), args);
+            indata, out.data<U>(), out.size(), args);
       });
     });
   });