Fix compilation with CUDA 11 (#2331)

This commit is contained in:
Cheng
2025-07-08 12:00:43 +09:00
committed by GitHub
parent 4a9b29a875
commit 2ca533b279
11 changed files with 115 additions and 56 deletions

View File

@@ -2,6 +2,8 @@
#pragma once
#include "mlx/backend/cuda/device/atomic_ops.cuh"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_utils.cuh"
@@ -40,15 +42,15 @@ struct Sum {
}
__device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
atomicAdd(x, y);
atomic_add(x, y);
}
__device__ void atomic_update(int* x, int y) {
atomicAdd(x, y);
atomic_add(x, y);
}
__device__ void atomic_update(float* x, float y) {
atomicAdd(x, y);
atomic_add(x, y);
}
};
@@ -152,7 +154,7 @@ struct ReduceInit<Sum, T> {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{0, 0};
} else {
return typename ReduceResult<Sum, T>::type{0};
return cast_to<typename ReduceResult<Sum, T>::type>(0);
}
}
};
@@ -163,7 +165,7 @@ struct ReduceInit<Prod, T> {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{1, 0};
} else {
return typename ReduceResult<Prod, T>::type{1};
return cast_to<typename ReduceResult<Prod, T>::type>(1);
}
}
};