mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix compilation with CUDA 11 (#2331)
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
@@ -128,7 +127,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
|
||||
totals[i] = op(totals[i], cast_to<U>(vals[i]));
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
@@ -137,7 +136,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
|
||||
totals[i] = op(totals[i], cast_to<U>(vals[i]));
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
@@ -150,9 +149,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
in + loop.location(),
|
||||
vals,
|
||||
args.reduction_stride - tile_x * BN,
|
||||
__cast<T, U>(ReduceInit<Op, T>::value()));
|
||||
cast_to<T>(ReduceInit<Op, T>::value()));
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(totals[i], __cast<U, T>(vals[i]));
|
||||
totals[i] = op(totals[i], cast_to<U>(vals[i]));
|
||||
}
|
||||
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user