Add comments and clean up

parent abdb21f27c
commit 664d8e42b8
@@ -77,88 +77,6 @@ struct RowReduceArgs {
   }
 };
 
-// template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
-//__global__ void row_reduce_small(
-//    const T* in,
-//    U* out,
-//    size_t out_size,
-//    const __grid_constant__ RowReduceArgs args) {
-//  size_t out_idx = cg::this_grid().thread_rank();
-//  if (out_idx >= out_size) {
-//    return;
-//  }
-//
-//  Op op;
-//
-//  U total_val = ReduceInit<Op, T>::value();
-//  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-//
-//  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(),
-//  args.ndim);
-//
-//  for (size_t n = 0; n < args.non_row_reductions; n++) {
-//    for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
-//      U vals[N_READS];
-//      cub::LoadDirectBlocked(
-//          r,
-//          make_cast_iterator<U>(in + loop.location()),
-//          vals,
-//          args.row_size,
-//          ReduceInit<Op, T>::value());
-//      total_val = op(total_val, cub::ThreadReduce(vals, op));
-//    }
-//    loop.next(args.reduce_shape.data(), args.reduce_strides.data());
-//  }
-//
-//  out[out_idx] = total_val;
-// }
-//
-// template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
-//__global__ void row_reduce_small_warp(
-//    const T* in,
-//    U* out,
-//    size_t out_size,
-//    const __grid_constant__ RowReduceArgs args) {
-//  auto grid = cg::this_grid();
-//  auto block = cg::this_thread_block();
-//  auto warp = cg::tiled_partition<WARP_SIZE>(block);
-//
-//  size_t out_idx = grid.thread_rank() / WARP_SIZE;
-//  if (out_idx >= out_size) {
-//    return;
-//  }
-//
-//  Op op;
-//
-//  U total_val = ReduceInit<Op, T>::value();
-//  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-//
-//  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(),
-//  args.ndim);
-//
-//  for (size_t n = warp.thread_rank(); n < args.non_row_reductions;
-//       n += WARP_SIZE) {
-//    for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) {
-//      U vals[N_READS];
-//      cub::LoadDirectBlocked(
-//          r,
-//          make_cast_iterator<U>(in + loop.location()),
-//          vals,
-//          args.row_size,
-//          ReduceInit<Op, T>::value());
-//      total_val = op(total_val, cub::ThreadReduce(vals, op));
-//    }
-//    loop.next(WARP_SIZE, args.reduce_shape.data(),
-//    args.reduce_strides.data());
-//  }
-//
-//  total_val = cg::reduce(warp, total_val, op);
-//
-//  if (warp.thread_rank() == 0) {
-//    out[out_idx] = total_val;
-//  }
-// }
-
 template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
 __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
   auto grid = cg::this_grid();
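For reference, the commented-out row_reduce_small kernel removed above assigns one thread per output element: the thread walks its row (and any non-row reduction dimensions) and accumulates a single value before writing it out. Below is a minimal standalone sketch of that strategy, assuming a float sum over contiguous rows; the file name, sizes, and the absence of the RowReduceArgs shape/stride handling and cub vectorized loads are illustrative simplifications, not the MLX code.

// row_reduce_small_sketch.cu -- illustrative sketch, not the MLX kernel.
#include <cstdio>
#include <cuda_runtime.h>

// One thread per output row: the thread walks its contiguous row and
// accumulates a running sum, mirroring the "small row" strategy above
// (minus the shape/stride bookkeeping done through RowReduceArgs).
__global__ void row_sum_small(const float* in, float* out, size_t n_rows, int row_size) {
  size_t row = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
  if (row >= n_rows) {
    return;
  }
  const float* x = in + row * static_cast<size_t>(row_size);
  float total = 0.0f; // the reduction's init value (ReduceInit for a sum)
  for (int i = 0; i < row_size; i++) {
    total += x[i];
  }
  out[row] = total;
}

int main() {
  const size_t n_rows = 1024;
  const int row_size = 33;
  float *in, *out;
  cudaMallocManaged(&in, n_rows * row_size * sizeof(float));
  cudaMallocManaged(&out, n_rows * sizeof(float));
  for (size_t i = 0; i < n_rows * row_size; i++) {
    in[i] = 1.0f;
  }
  row_sum_small<<<(n_rows + 255) / 256, 256>>>(in, out, n_rows, row_size);
  cudaDeviceSynchronize();
  printf("out[0] = %f (expect %d)\n", out[0], row_size);
  cudaFree(in);
  cudaFree(out);
  return 0;
}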
@@ -286,100 +204,8 @@ __global__ void row_reduce_looped(
   }
 }
 
-template <typename T, typename U, typename Op, int N = 4>
-__global__ void reduce_initialize(U* out, size_t out_size) {
-  auto grid = cg::this_grid();
-  if (grid.thread_rank() * N + N <= out_size) {
-    for (int i = 0; i < N; i++) {
-      out[grid.thread_rank() * N + i] = ReduceInit<Op, T>::value();
-    }
-  } else {
-    for (int i = grid.thread_rank() * N; i < out_size; i++) {
-      out[i] = ReduceInit<Op, T>::value();
-    }
-  }
-}
-
-template <typename T, typename U, typename Op, int BLOCK_DIM, int N_READS = 4>
-__global__ void row_reduce_atomics(
-    T* in,
-    U* out,
-    size_t out_size,
-    const __grid_constant__ RowReduceArgs args) {
-  auto grid = cg::this_grid();
-  auto block = cg::this_thread_block();
-  auto warp = cg::tiled_partition<WARP_SIZE>(block);
-
-  size_t reduction_idx = grid.block_rank() / out_size;
-  size_t out_idx = grid.block_rank() % out_size;
-
-  Op op;
-
-  U total[1];
-  U init = ReduceInit<Op, T>::value();
-  total[0] = init;
-  size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
-  size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
-
-  in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
-  in += elem_to_loc(
-      reduction_idx,
-      args.reduce_shape.data(),
-      args.reduce_strides.data(),
-      args.reduce_ndim);
-
-  for (size_t r = 0; r < full_blocks; r++) {
-    T vals[N_READS];
-    cub::LoadDirectBlockedVectorized<T, N_READS>(
-        block.thread_rank(), in + r * BLOCK_DIM * N_READS, vals);
-    for (int i = 0; i < N_READS; i++) {
-      total[0] = op(total[0], __cast<U, T>(vals[i]));
-    }
-  }
-  if (final_offset < args.row_size) {
-    T vals[N_READS];
-    cub::LoadDirectBlocked(
-        block.thread_rank(),
-        in + final_offset,
-        vals,
-        args.row_size - final_offset,
-        __cast<T, U>(init));
-    for (int i = 0; i < N_READS; i++) {
-      total[0] = op(total[0], __cast<U, T>(vals[i]));
-    }
-  }
-
-  __shared__ U shared_accumulators[32];
-  block_reduce(block, warp, total, shared_accumulators, op, init);
-
-  if (block.thread_rank() == 0) {
-    op.atomic_update(out + out_idx, total[0]);
-  }
-}
-
 } // namespace cu
 
-void reduce_initialize(
-    cu::CommandEncoder& encoder,
-    array& out,
-    Reduce::ReduceType reduce_type) {
-  constexpr int N_WRITES = 8;
-  encoder.set_output_array(out);
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using T = cuda_type_t<CTYPE>;
-        using U = cu::ReduceResult<OP, T>::type;
-
-        auto kernel = cu::reduce_initialize<T, U, OP, N_WRITES>;
-        auto [grid, block] =
-            get_launch_args(kernel, out, out.size() >= 1UL << 31, N_WRITES);
-        kernel<<<grid, block, 0, stream>>>(out.data<U>(), out.size());
-      });
-    });
-  });
-}
-
 void row_reduce_simple(
     cu::CommandEncoder& encoder,
     const array& in,
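The removed row_reduce_atomics path above splits each row across several threadblocks: every block reduces its slice, runs a block-level reduction, and thread 0 then folds the block's partial result into global memory via op.atomic_update. That is why the output first had to be filled with the reduction's init value by reduce_initialize. Below is a minimal sketch of the same idea for a float sum; the kernel name, fixed block size, and plain atomicAdd are assumptions for illustration, whereas the removed code works through cub loads and the Op's atomic_update.

// row_reduce_atomics_sketch.cu -- illustrative sketch, not the MLX kernel.
#include <cstdio>
#include <cuda_runtime.h>

// Each block reduces one chunk of one row (blockIdx.x = chunk, blockIdx.y =
// row), combines its partial sums in shared memory, and thread 0 adds the
// block's result to the output atomically. The output must already hold the
// reduction's init value (here 0.0f), which was the job of reduce_initialize.
template <int BLOCK_DIM>
__global__ void row_sum_atomics(const float* in, float* out, int row_size) {
  int row = blockIdx.y;
  int i = blockIdx.x * BLOCK_DIM + threadIdx.x;

  float partial = (i < row_size) ? in[static_cast<size_t>(row) * row_size + i] : 0.0f;

  __shared__ float scratch[BLOCK_DIM];
  scratch[threadIdx.x] = partial;
  __syncthreads();
  for (int offset = BLOCK_DIM / 2; offset > 0; offset /= 2) {
    if (threadIdx.x < offset) {
      scratch[threadIdx.x] += scratch[threadIdx.x + offset];
    }
    __syncthreads();
  }

  if (threadIdx.x == 0) {
    atomicAdd(out + row, scratch[0]); // op.atomic_update in the removed kernel
  }
}

int main() {
  const int n_rows = 8, row_size = 1000;
  float *in, *out;
  cudaMallocManaged(&in, sizeof(float) * n_rows * row_size);
  cudaMallocManaged(&out, sizeof(float) * n_rows);
  for (int i = 0; i < n_rows * row_size; i++) in[i] = 1.0f;
  for (int i = 0; i < n_rows; i++) out[i] = 0.0f; // the "reduce_initialize" step
  constexpr int BLOCK = 256;
  dim3 grid((row_size + BLOCK - 1) / BLOCK, n_rows);
  row_sum_atomics<BLOCK><<<grid, BLOCK>>>(in, out, row_size);
  cudaDeviceSynchronize();
  printf("out[0] = %f (expect %d)\n", out[0], row_size);
  cudaFree(in);
  cudaFree(out);
  return 0;
}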
@@ -480,66 +306,6 @@ void row_reduce_looped(
   });
 }
 
-void row_reduce_atomics(
-    cu::CommandEncoder& encoder,
-    const array& in,
-    array& out,
-    Reduce::ReduceType reduce_type,
-    const std::vector<int>& axes,
-    const ReductionPlan& plan,
-    cu::RowReduceArgs args) {
-  constexpr int N_READS = 8;
-
-  // Allocate data for the output using in's layout to access them as
-  // contiguously as possible.
-  allocate_same_layout(out, in, axes);
-
-  // Just a way to get out of the constness because cub doesn't like it ...
-  // (sigh)
-  array x = in;
-
-  // Initialize
-  reduce_initialize(encoder, out, reduce_type);
-
-  // Launch the reduction
-  encoder.set_input_array(x);
-  encoder.set_output_array(out);
-  encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using T = cuda_type_t<CTYPE>;
-        using U = cu::ReduceResult<OP, T>::type;
-
-        args.convert_shapes_to_contiguous(x, axes);
-        dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
-        if (grid.x * args.non_row_reductions < INT_MAX) {
-          grid.x *= args.non_row_reductions;
-        } else if (grid.y * args.non_row_reductions < 65536) {
-          grid.y *= args.non_row_reductions;
-        } else {
-          throw std::runtime_error(
-              "[row_reduce_atomics] Non-row reductions need to be factorized which is NYI");
-        }
-        size_t reductions = args.row_size / N_READS;
-        int threads = std::min(1024UL, reductions);
-        threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
-        dim3 block(threads, 1, 1);
-
-        // Pick the kernel
-        auto kernel = cu::row_reduce_atomics<T, U, OP, 32, N_READS>;
-        MLX_SWITCH_BLOCK_DIM(threads, THREADS, {
-          kernel = cu::row_reduce_atomics<T, U, OP, THREADS, N_READS>;
-          block.x = THREADS;
-        });
-
-        // Launch
-        kernel<<<grid, block, 0, stream>>>(
-            x.data<T>(), out.data<U>(), out.size(), args);
-      });
-    });
-  });
-}
-
 void row_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -547,6 +313,23 @@ void row_reduce(
     Reduce::ReduceType reduce_type,
     const std::vector<int>& axes,
     const ReductionPlan& plan) {
+  // Current row reduction options
+  //
+  // - row_reduce_simple
+  //
+  //   That means that we are simply reducing across the fastest moving axis.
+  //   We are reducing 1 or 2 rows per threadblock depending on the size of
+  //   output.
+  //
+  // - row_reduce_looped
+  //
+  //   It is a general row reduction. We are computing 1 output per
+  //   threadblock. We read the fastest moving axis vectorized and loop over
+  //   the rest of the axes.
+  //
+  // Notes: We opt to read as much in order as possible and leave
+  //        transpositions as they are (contrary to our Metal backend).
+
   // Simple row reduce means that we have 1 axis that we are reducing over and
   // it has stride 1.
   if (plan.shape.size() == 1) {
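The comments added in this hunk describe the two remaining strategies: row_reduce_simple for plain contiguous rows (one or two rows per threadblock) and row_reduce_looped for the general case (one output per threadblock, vectorized reads along the fastest axis, a loop over the remaining reduction axes). Below is a minimal sketch of the block-per-output pattern for a float sum, using cooperative groups for the warp/block combination; the names, block size, and contiguous single-axis row are assumptions, and the real kernel additionally handles arbitrary strides and non-row reduction dimensions.

// row_reduce_looped_sketch.cu -- illustrative sketch, not the MLX kernel.
#include <cstdio>
#include <cuda_runtime.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

// One threadblock per output: threads stride over the row accumulating
// partial sums, each warp reduces its partials, and the first warp combines
// the per-warp results from shared memory into the final value.
template <int BLOCK_DIM>
__global__ void row_sum_looped(const float* in, float* out, int row_size) {
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);
  const float* x = in + static_cast<size_t>(blockIdx.x) * row_size;

  float partial = 0.0f;
  for (int i = block.thread_rank(); i < row_size; i += BLOCK_DIM) {
    partial += x[i];
  }
  partial = cg::reduce(warp, partial, cg::plus<float>());

  __shared__ float warp_sums[BLOCK_DIM / 32];
  if (warp.thread_rank() == 0) {
    warp_sums[warp.meta_group_rank()] = partial;
  }
  block.sync();

  if (warp.meta_group_rank() == 0) {
    float v = (warp.thread_rank() < BLOCK_DIM / 32) ? warp_sums[warp.thread_rank()] : 0.0f;
    v = cg::reduce(warp, v, cg::plus<float>());
    if (warp.thread_rank() == 0) {
      out[blockIdx.x] = v;
    }
  }
}

int main() {
  const int n_rows = 16, row_size = 4096;
  float *in, *out;
  cudaMallocManaged(&in, sizeof(float) * n_rows * row_size);
  cudaMallocManaged(&out, sizeof(float) * n_rows);
  for (int i = 0; i < n_rows * row_size; i++) in[i] = 1.0f;
  constexpr int BLOCK = 256;
  row_sum_looped<BLOCK><<<n_rows, BLOCK>>>(in, out, row_size);
  cudaDeviceSynchronize();
  printf("out[0] = %f (expect %d)\n", out[0], row_size);
  cudaFree(in);
  cudaFree(out);
  return 0;
}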
@@ -557,62 +340,8 @@ void row_reduce(
   // Make the args struct to help route to the best kernel
   cu::RowReduceArgs args(in, plan, axes);
 
-  // Let's use atomics to increase parallelism
-  if (false && args.row_size < 512) {
-    row_reduce_atomics(
-        encoder, in, out, reduce_type, axes, plan, std::move(args));
-  }
-
   // Fallback row reduce
   row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
-
-  // encoder.launch_kernel([&](cudaStream_t stream) {
-  //   MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-  //     using InType = cuda_type_t<CTYPE>;
-  //     MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-  //       using OutType = cu::ReduceResult<OP, InType>::type;
-  //       MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
-  //         constexpr size_t N_READS = 4;
-  //         dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides());
-  //         dim3 block_dims, num_blocks;
-  //         auto kernel =
-  //             cu::row_reduce_small<InType, OutType, OP, NDIM, N_READS>;
-  //         if (args.row_size <= 64) {
-  //           if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
-  //               (args.non_row_reductions <= 8)) {
-  //             block_dims.x = std::min(out_dims.x, 1024u);
-  //             num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x);
-  //             num_blocks.y = out_dims.y;
-  //           } else {
-  //             block_dims.x = WARP_SIZE;
-  //             num_blocks.y = out_dims.x;
-  //             num_blocks.z = out_dims.y;
-  //             kernel =
-  //                 cu::row_reduce_small_warp<InType, OutType, OP, NDIM,
-  //                 N_READS>;
-  //           }
-  //         } else {
-  //           size_t num_threads = cuda::ceil_div(args.row_size, N_READS);
-  //           num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE;
-  //           MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, {
-  //             num_blocks.y = out_dims.x;
-  //             num_blocks.z = out_dims.y;
-  //             block_dims.x = BLOCK_DIM_X;
-  //             kernel = cu::row_reduce_looped<
-  //                 InType,
-  //                 OutType,
-  //                 OP,
-  //                 NDIM,
-  //                 BLOCK_DIM_X,
-  //                 N_READS>;
-  //           });
-  //         }
-  //         kernel<<<num_blocks, block_dims, 0, stream>>>(
-  //             in.data<InType>(), out.data<OutType>(), out.size(), args);
-  //       });
-  //     });
-  //   });
-  // });
 }
 
 } // namespace mlx::core