Fix compilation with CUDA 11 (#2331)

This commit is contained in:
Cheng
2025-07-08 12:00:43 +09:00
committed by GitHub
parent 4a9b29a875
commit 2ca533b279
11 changed files with 115 additions and 56 deletions

View File

@@ -3,7 +3,6 @@
#include <numeric>
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include <cooperative_groups.h>
@@ -128,7 +127,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
T vals[N_READS];
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
totals[i] = op(totals[i], cast_to<U>(vals[i]));
}
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
@@ -137,7 +136,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
totals[i] = op(totals[i], cast_to<U>(vals[i]));
}
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
@@ -150,9 +149,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
in + loop.location(),
vals,
args.reduction_stride - tile_x * BN,
__cast<T, U>(ReduceInit<Op, T>::value()));
cast_to<T>(ReduceInit<Op, T>::value()));
for (int i = 0; i < N_READS; i++) {
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
totals[i] = op(totals[i], cast_to<U>(vals[i]));
}
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}