mirror of https://github.com/ml-explore/mlx.git

Strided scan
@@ -39,37 +39,37 @@ struct ReduceInit<LogAddExp, T> {
 template <bool reverse, typename T, typename U, int N_READS>
 inline __device__ void
-load_vals(int index, const T* in, U (&vals)[N_READS], int size, U init) {
+load_values(int index, const T* in, U (&values)[N_READS], int size, U init) {
   int remaining = size - index * N_READS;
   if constexpr (reverse) {
     in += remaining - N_READS;
     if (remaining < N_READS) {
       for (int i = 0; i < N_READS; ++i) {
-        vals[N_READS - i - 1] =
+        values[N_READS - i - 1] =
             (N_READS - i - 1 < remaining) ? cast_to<U>(in[i]) : init;
       }
     } else {
       for (int i = 0; i < N_READS; ++i) {
-        vals[N_READS - i - 1] = cast_to<U>(in[i]);
+        values[N_READS - i - 1] = cast_to<U>(in[i]);
       }
     }
   } else {
     in += index * N_READS;
     if (remaining < N_READS) {
       for (int i = 0; i < N_READS; ++i) {
-        vals[i] = (i < remaining) ? cast_to<U>(in[i]) : init;
+        values[i] = (i < remaining) ? cast_to<U>(in[i]) : init;
       }
     } else {
       for (int i = 0; i < N_READS; ++i) {
-        vals[i] = cast_to<U>(in[i]);
+        values[i] = cast_to<U>(in[i]);
       }
     }
   }
 }
 
-template <bool reverse, typename T, int N_READS>
+template <bool reverse, int offset, typename T, int N_READS>
 inline __device__ void
-store_vals(int index, T* out, T (&vals)[N_READS], int size, int offset = 0) {
+store_values(int index, T* out, T (&values)[N_READS], int size) {
   int start = index * N_READS + offset;
   int remaining = size - start;
   if constexpr (reverse) {
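The hunk above renames `load_vals`/`vals` to `load_values`/`values` and promotes the store helper's `offset` from a defaulted runtime argument to a template parameter, so the one-element shift used for exclusive scans is resolved at compile time. The subtle part of `load_values` is the reverse tail: when fewer than `N_READS` elements remain, the register array is padded with `init` instead of reading out of bounds. A minimal host-side C++ analog of that logic (the sizes and the plain `float` copy standing in for `cast_to` are illustrative, not from the commit):

    #include <cstdio>

    // Host-side analog of load_values<reverse=true>: read a partial segment
    // of `remaining` elements into registers in reversed layout, padding the
    // unused slots with `init` instead of reading past the segment.
    template <int N_READS>
    void load_reverse_tail(
        const float* in, float (&values)[N_READS], int remaining, float init) {
      for (int i = 0; i < N_READS; ++i) {
        int src = remaining - N_READS + i; // index into the segment, may be < 0
        values[N_READS - i - 1] = (src >= 0) ? in[src] : init;
      }
    }

    int main() {
      float data[3] = {1.f, 2.f, 3.f}; // only 3 valid elements, N_READS = 4
      float values[4];
      load_reverse_tail<4>(data, values, /*remaining=*/3, /*init=*/0.f);
      for (float v : values) {
        std::printf("%g ", v); // prints: 3 2 1 0
      }
    }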
@@ -77,12 +77,12 @@ store_vals(int index, T* out, T (&vals)[N_READS], int size, int offset = 0) {
     if (remaining < N_READS) {
       for (int i = 0; i < N_READS; ++i) {
         if (N_READS - i - 1 < remaining) {
-          out[i] = vals[N_READS - i - 1];
+          out[i] = values[N_READS - i - 1];
         }
       }
     } else {
       for (int i = 0; i < N_READS; ++i) {
-        out[i] = vals[N_READS - i - 1];
+        out[i] = values[N_READS - i - 1];
       }
     }
   } else {
@@ -90,12 +90,12 @@ store_vals(int index, T* out, T (&vals)[N_READS], int size, int offset = 0) {
     if (remaining < N_READS) {
       for (int i = 0; i < N_READS; ++i) {
         if (i < remaining) {
-          out[i] = vals[i];
+          out[i] = values[i];
         }
       }
     } else {
       for (int i = 0; i < N_READS; ++i) {
-        out[i] = vals[i];
+        out[i] = values[i];
       }
     }
   }
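The two store hunks are the same rename, but the `offset` template parameter introduced in the first hunk is what makes exclusive scans work: the inclusive results are stored shifted by one slot, and the boundary slot is then filled with the operator's identity (the `out[axis_size - 1] = init` write further down). A plain C++ sketch of that shift-and-init pattern, with a sum over `int` as an illustrative op:

    #include <cstdio>
    #include <vector>

    int main() {
      // Inclusive scan first, as the kernel computes it in registers.
      std::vector<int> in = {1, 2, 3, 4};
      std::vector<int> inclusive(in.size());
      int acc = 0; // init (identity) for Sum
      for (size_t i = 0; i < in.size(); ++i) {
        inclusive[i] = (acc += in[i]);
      }
      // Exclusive output = inclusive results stored one slot later, with the
      // boundary set to init (store_values<reverse, 1> plus the init write).
      std::vector<int> out(in.size());
      out[0] = 0;
      for (size_t i = 1; i < in.size(); ++i) {
        out[i] = inclusive[i - 1];
      }
      for (int v : out) {
        std::printf("%d ", v); // prints: 0 1 3 6
      }
    }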
@@ -125,24 +125,24 @@ __global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) {
   // Scan per block.
   for (int r = 0; r < cuda::ceil_div(axis_size, block.size() * N_READS); ++r) {
     int32_t index = r * block.size() + block.thread_rank();
-    U vals[N_READS];
-    load_vals<reverse>(index, in, vals, axis_size, init);
+    U values[N_READS];
+    load_values<reverse>(index, in, values, axis_size, init);
 
     // Compute an inclusive scan per thread.
-    for (int i = 1; i < N_READS; i++) {
-      vals[i] = op(vals[i], vals[i - 1]);
+    for (int i = 1; i < N_READS; ++i) {
+      values[i] = op(values[i], values[i - 1]);
     }
 
     // Compute exclusive scan of thread sums.
-    U prev_thread_sum = cg::exclusive_scan(warp, vals[N_READS - 1], op);
+    U prev_thread_sum = cg::exclusive_scan(warp, values[N_READS - 1], op);
     if (warp.thread_rank() == 0) {
       prev_thread_sum = init;
     }
 
     // Write warp's sum to shared memory.
-    if (warp.thread_rank() == warp.size() - 1) {
+    if (warp.thread_rank() == WARP_SIZE - 1) {
       warp_sums[warp.meta_group_rank()] =
-          op(prev_thread_sum, vals[N_READS - 1]);
+          op(prev_thread_sum, values[N_READS - 1]);
     }
     block.sync();
 
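`contiguous_scan` composes three levels: an inclusive scan over each thread's `N_READS` registers, a warp-level `cg::exclusive_scan` over the per-thread sums, and warp totals staged through shared memory. The kernel does not rely on lane 0's exclusive result and overwrites it with `init`. A standalone CUDA sketch of just the warp-level step (CUDA 11+; the demo values are illustrative, not from the commit):

    #include <cooperative_groups.h>
    #include <cooperative_groups/reduce.h>
    #include <cooperative_groups/scan.h>
    #include <cstdio>

    namespace cg = cooperative_groups;

    __global__ void warp_scan_demo(int* out) {
      auto block = cg::this_thread_block();
      auto warp = cg::tiled_partition<32>(block);
      // Each lane contributes a "thread sum" of rank + 1.
      int v =
          cg::exclusive_scan(warp, int(warp.thread_rank()) + 1, cg::plus<int>());
      if (warp.thread_rank() == 0) {
        v = 0; // overwrite lane 0 with the identity, as contiguous_scan does
      }
      out[warp.thread_rank()] = v;
    }

    int main() {
      int* d;
      int h[32];
      cudaMalloc(&d, sizeof(h));
      warp_scan_demo<<<1, 32>>>(d);
      cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
      for (int v : h) {
        std::printf("%d ", v); // prints: 0 1 3 6 10 15 ...
      }
      cudaFree(d);
    }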
@@ -159,16 +159,16 @@ __global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) {
 
     // Compute the output.
     for (int i = 0; i < N_READS; ++i) {
-      vals[i] = op(vals[i], prefix);
-      vals[i] = op(vals[i], warp_sums[warp.meta_group_rank()]);
-      vals[i] = op(vals[i], prev_thread_sum);
+      values[i] = op(values[i], prefix);
+      values[i] = op(values[i], warp_sums[warp.meta_group_rank()]);
+      values[i] = op(values[i], prev_thread_sum);
     }
 
     // Write the values.
     if (inclusive) {
-      store_vals<reverse>(index, out, vals, axis_size);
+      store_values<reverse, 0>(index, out, values, axis_size);
     } else {
-      store_vals<reverse>(index, out, vals, axis_size, 1);
+      store_values<reverse, 1>(index, out, values, axis_size);
       if (reverse) {
         if (block.thread_rank() == 0 && index == 0) {
           out[axis_size - 1] = init;
@@ -183,14 +183,141 @@ __global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) {
 
     // Share the prefix.
     if ((warp.meta_group_rank() == warp.meta_group_size() - 1) &&
-        (warp.thread_rank() == warp.size() - 1)) {
-      warp_sums[0] = vals[N_READS - 1];
+        (warp.thread_rank() == WARP_SIZE - 1)) {
+      warp_sums[0] = values[N_READS - 1];
     }
     block.sync();
     prefix = warp_sums[0];
   }
 }
 
+template <
+    typename T,
+    typename U,
+    typename Op,
+    int N_READS,
+    int BM,
+    int BN,
+    bool inclusive,
+    bool reverse>
+__global__ void strided_scan(
+    const T* in,
+    U* out,
+    int32_t axis_size,
+    int64_t stride,
+    int64_t stride_blocks) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  constexpr int BN_pad = WARP_SIZE + 16 / sizeof(U);
+  constexpr int n_warps = BN / N_READS;
+  constexpr int n_scans = BN / n_warps;
+
+  __shared__ U read_buffer[BM * BN_pad];
+
+  Op op;
+  U init = ReduceInit<Op, T>::value();
+  U values[n_scans];
+  U prefix[n_scans];
+  for (int i = 0; i < n_scans; ++i) {
+    prefix[i] = init;
+  }
+
+  // Compute offsets.
+  int64_t offset = (grid.block_rank() / stride_blocks) * axis_size * stride;
+  int64_t global_index_x = (grid.block_rank() % stride_blocks) * BN;
+  uint read_offset_y = (block.thread_rank() * N_READS) / BN;
+  uint read_offset_x = (block.thread_rank() * N_READS) % BN;
+  uint scan_offset_y = warp.thread_rank();
+  uint scan_offset_x = warp.meta_group_rank() * n_scans;
+
+  uint stride_limit = stride - global_index_x;
+  in += offset + global_index_x + read_offset_x;
+  out += offset + global_index_x + read_offset_x;
+  U* read_into = read_buffer + read_offset_y * BN_pad + read_offset_x;
+  U* read_from = read_buffer + scan_offset_y * BN_pad + scan_offset_x;
+
+  for (uint j = 0; j < axis_size; j += BM) {
+    // Calculate the indices for the current thread.
+    uint index_y = j + read_offset_y;
+    uint check_index_y = index_y;
+    if (reverse) {
+      index_y = axis_size - 1 - index_y;
+    }
+
+    // Read in SM.
+    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
+      for (int i = 0; i < N_READS; ++i) {
+        read_into[i] = in[index_y * stride + i];
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
+          read_into[i] = in[index_y * stride + i];
+        } else {
+          read_into[i] = init;
+        }
+      }
+    }
+    block.sync();
+
+    // Read strided into registers.
+    for (int i = 0; i < n_scans; ++i) {
+      values[i] = read_from[i];
+    }
+    warp.sync();
+
+    // Perform the scan.
+    for (int i = 0; i < n_scans; ++i) {
+      values[i] = cg::inclusive_scan(warp, values[i], op);
+      values[i] = op(values[i], prefix[i]);
+      prefix[i] = warp.shfl(values[i], WARP_SIZE - 1);
+    }
+
+    // Write to SM.
+    for (int i = 0; i < n_scans; ++i) {
+      read_from[i] = values[i];
+    }
+    block.sync();
+
+    // Write to device memory.
+    if (!inclusive) {
+      if (check_index_y == 0) {
+        if ((read_offset_x + N_READS) < stride_limit) {
+          for (int i = 0; i < N_READS; ++i) {
+            out[index_y * stride + i] = init;
+          }
+        } else {
+          for (int i = 0; i < N_READS; ++i) {
+            if ((read_offset_x + i) < stride_limit) {
+              out[index_y * stride + i] = init;
+            }
+          }
+        }
+      }
+      if (reverse) {
+        index_y -= 1;
+        check_index_y += 1;
+      } else {
+        index_y += 1;
+        check_index_y += 1;
+      }
+    }
+    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
+      for (int i = 0; i < N_READS; ++i) {
+        out[index_y * stride + i] = read_into[i];
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
+          out[index_y * stride + i] = read_into[i];
+        }
+      }
+    }
+  }
+}
+
 } // namespace cu
 
 template <typename F>
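The new `strided_scan` kernel handles scans along a non-contiguous axis. Each block walks the axis in BM-row steps over a BN-column tile: rows are loaded coalesced into a padded shared-memory tile (the `16 / sizeof(U)` extra columns of `BN_pad` presumably stagger accesses to reduce bank conflicts), each warp then scans `n_scans` columns in registers with `cg::inclusive_scan` down the warp lanes, and the per-column `prefix` carries the running total into the next BM chunk via `warp.shfl` from the last lane. A host-side reference of what the kernel computes, useful for validation (shapes and the sum op are illustrative):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Reference: scan each of `stride` interleaved columns of length
    // `axis_size`, the computation strided_scan performs tile by tile.
    void strided_scan_ref(
        const float* in, float* out, int axis_size, int64_t stride,
        bool inclusive, bool reverse) {
      for (int64_t col = 0; col < stride; ++col) {
        float acc = 0.f; // init for Sum
        for (int i = 0; i < axis_size; ++i) {
          int row = reverse ? axis_size - 1 - i : i;
          float x = in[row * stride + col];
          out[row * stride + col] = inclusive ? (acc + x) : acc;
          acc += x;
        }
      }
    }

    int main() {
      // 4 x 3 row-major array, scanned along axis 0 (stride = 3).
      std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
      std::vector<float> out(in.size());
      strided_scan_ref(in.data(), out.data(), 4, 3, /*inclusive=*/true, false);
      for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 3; ++c) {
          std::printf("%5.0f", out[r * 3 + c]); // col sums: 1 5 12 22, ...
        }
        std::printf("\n");
      }
    }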
@@ -259,6 +386,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     out.copy_shared_buffer(in);
   }
 
+  constexpr int N_READS = 4;
+  int32_t axis_size = in.shape(axis_);
   bool contiguous = in.strides()[axis_] == 1;
 
   auto& encoder = cu::get_command_encoder(s);
@@ -274,7 +403,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
       dispatch_bool(inclusive_, [&](auto inclusive) {
         dispatch_bool(reverse_, [&](auto reverse) {
           if (contiguous) {
-            constexpr int N_READS = 4;
             auto kernel = cu::contiguous_scan<
                 T,
                 U,
@@ -282,7 +410,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
                 N_READS,
                 inclusive.value,
                 reverse.value>;
-            int32_t axis_size = in.shape(axis_);
             int block_dim = cuda::ceil_div(axis_size, N_READS);
             block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE;
             block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE);
@@ -294,7 +421,32 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
                 out.data<U>(),
                 axis_size);
           } else {
-            throw std::runtime_error("Strided Scan NYI");
+            constexpr int BM = WARP_SIZE;
+            constexpr int BN = WARP_SIZE;
+            auto kernel = cu::strided_scan<
+                T,
+                U,
+                Op,
+                N_READS,
+                BM,
+                BN,
+                inclusive.value,
+                reverse.value>;
+            int64_t stride = in.strides()[axis_];
+            int64_t stride_blocks = cuda::ceil_div(stride, BN);
+            dim3 num_blocks = get_2d_grid_dims(
+                in.shape(), in.strides(), axis_size * stride);
+            num_blocks.x *= stride_blocks;
+            int block_dim = BN / N_READS * WARP_SIZE;
+            encoder.add_kernel_node(
+                kernel,
+                num_blocks,
+                block_dim,
+                in.data<T>(),
+                out.data<U>(),
+                axis_size,
+                stride,
+                stride_blocks);
           }
         });
       });
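The launch configuration mirrors the kernel's tile shape: the grid's x dimension is multiplied by `stride_blocks`, one `BM x BN` tile per slice of the stride dimension, with `BN / N_READS` warps per block. Worked through for this commit's defaults (`WARP_SIZE = 32`, `N_READS = 4`, `BM = BN = 32`; the stride value is illustrative):

    #include <cstdint>
    #include <cstdio>

    constexpr int64_t ceil_div(int64_t a, int64_t b) {
      return (a + b - 1) / b;
    }

    int main() {
      constexpr int WARP_SIZE = 32, N_READS = 4, BN = 32;
      int64_t stride = 100; // product of dims after the scan axis (example)
      int64_t stride_blocks = ceil_div(stride, BN); // 4 blocks tile 100 cols
      int block_dim = BN / N_READS * WARP_SIZE;     // 8 warps = 256 threads
      std::printf("stride_blocks=%lld block_dim=%d\n",
                  (long long)stride_blocks, block_dim);
    }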