Fix compilation with CUDA 11 (#2331)

This commit is contained in:
Cheng
2025-07-08 12:00:43 +09:00
committed by GitHub
parent 4a9b29a875
commit 2ca533b279
11 changed files with 115 additions and 56 deletions

View File

@@ -3,6 +3,8 @@
#pragma once
#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <thrust/iterator/transform_iterator.h>
namespace mlx::core::cu {
@@ -17,6 +19,26 @@ struct CastOp {
}
};
// Castings between complex and boolean.
// TODO: Should make a custom complex type.
template <>
struct CastOp<cuComplex, bool> {
static constexpr bool is_castable = true;
__device__ bool operator()(cuComplex x) {
return x.x != 0 && x.y != 0;
}
};
template <>
struct CastOp<bool, cuComplex> {
static constexpr bool is_castable = true;
__device__ cuComplex operator()(bool x) {
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
}
};
// Converting a complex number to real number discards the imaginary part.
template <typename DstT>
struct CastOp<
@@ -45,6 +67,7 @@ struct CastOp<
}
};
// Do nothing when no casting is needed.
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
@@ -57,9 +80,53 @@ struct CastOp<
}
};
// In CUDA 11 the half types do not define conversions between some types,
// provide fallbacks here.
#if CUDART_VERSION < 12000
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
DstT,
cuda::std::enable_if_t<
!cuda::std::is_convertible_v<SrcT, DstT> &&
!cuda::std::is_same_v<SrcT, cuComplex> &&
(cuda::std::is_same_v<DstT, __half> ||
cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
static constexpr bool is_castable = true;
__device__ DstT operator()(SrcT x) {
return DstT(static_cast<float>(x));
}
};
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
DstT,
cuda::std::enable_if_t<
!cuda::std::is_convertible_v<SrcT, DstT> &&
!cuda::std::is_same_v<DstT, cuComplex> &&
!cuda::std::is_same_v<DstT, __half> &&
!cuda::std::is_same_v<DstT, __nv_bfloat16> &&
(cuda::std::is_same_v<SrcT, __half> ||
cuda::std::is_same_v<SrcT, __nv_bfloat16>)>> {
static constexpr bool is_castable = true;
__device__ DstT operator()(SrcT x) {
return DstT(static_cast<float>(x));
}
};
#endif // CUDART_VERSION < 12000
// Helper to deduce the SrcT.
template <typename DstT, typename SrcT>
inline __host__ __device__ auto cast_to(SrcT x) {
return CastOp<SrcT, DstT>{}(x);
}
// Return an iterator that cast the value to DstT using CastOp.
template <typename DstT, typename Iterator>
__host__ __device__ auto make_cast_iterator(Iterator it) {
inline __host__ __device__ auto make_cast_iterator(Iterator it) {
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
if constexpr (std::is_same_v<SrcT, DstT>) {
return it;

View File

@@ -99,20 +99,20 @@ struct Limits<
return cuda::std::numeric_limits<T>::infinity();
}
static constexpr __host__ __device__ T min() {
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
return -cuda::std::numeric_limits<T>::infinity();
#else
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
return -cuda::std::numeric_limits<float>::infinity();
#else
return -cuda::std::numeric_limits<T>::infinity();
#endif
}
static constexpr __host__ __device__ T finite_max() {
return cuda::std::numeric_limits<T>::max();
}
static constexpr __host__ __device__ T finite_min() {
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
return cuda::std::numeric_limits<T>::lowest();
#else
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
return cuda::std::numeric_limits<float>::lowest();
#else
return cuda::std::numeric_limits<T>::lowest();
#endif
}
};