mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix compilation with CUDA 11 (#2331)
This commit is contained in:
@@ -37,15 +37,15 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
|
||||
for (; i + block.size() * N <= check; i += block.size() * N) {
|
||||
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[0] = op(accs[0], __cast<U, T>(vals[j]));
|
||||
accs[0] = op(accs[0], cast_to<U>(vals[j]));
|
||||
}
|
||||
}
|
||||
|
||||
if (i < check) {
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
|
||||
block.thread_rank(), in + i, vals, check - i, cast_to<T>(init));
|
||||
for (int i = 0; i < N; i++) {
|
||||
accs[0] = op(accs[0], __cast<U, T>(vals[i]));
|
||||
accs[0] = op(accs[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user