Fix compilation with CUDA 11 (#2331)

This commit is contained in:
Cheng
2025-07-08 12:00:43 +09:00
committed by GitHub
parent 4a9b29a875
commit 2ca533b279
11 changed files with 115 additions and 56 deletions

View File

@@ -37,15 +37,15 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
for (; i + block.size() * N <= check; i += block.size() * N) {
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
for (int j = 0; j < N; j++) {
accs[0] = op(accs[0], __cast<U, T>(vals[j]));
accs[0] = op(accs[0], cast_to<U>(vals[j]));
}
}
if (i < check) {
cub::LoadDirectBlocked(
block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
block.thread_rank(), in + i, vals, check - i, cast_to<T>(init));
for (int i = 0; i < N; i++) {
accs[0] = op(accs[0], __cast<U, T>(vals[i]));
accs[0] = op(accs[0], cast_to<U>(vals[i]));
}
}