[CUDA] Do vectorized store/load in contiguous elementwise ops (#2342)

* Do vectorized store/load in unary ops

* Do vectorized store/load in binary_two ops

* Do vectorized store/load in copy ops

* Do vectorized store/load in ternary ops

* Use int32_t for IdxT

* binary => binary_two in binary_two.cu

* Fix tests on large arrays

* Use uint as index type

* Contig uses uint as index and non-contig uses int
This commit is contained in:
Cheng
2025-07-10 10:48:43 +09:00
committed by GitHub
parent e14ee12491
commit 85873cb162
5 changed files with 223 additions and 96 deletions

View File

@@ -15,12 +15,27 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename T, typename IdxT>
template <typename Op, typename T, typename IdxT, int N_READS>
__global__ void
ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
out[index] = Op{}(a[index], b[index], c[index]);
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[i], b[i], c[i]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
auto b_vec = load_vector<N_READS>(b, index);
auto c_vec = load_vector<N_READS>(c, index);
AlignedVector<T, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
@@ -149,11 +164,18 @@ void ternary_op_gpu_inplace(
}
});
} else {
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::ternary_v<Op, DType, IdxT>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), large());
kernel,
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,