mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix compilation with CUDA 11 (#2331)
This commit is contained in:
@@ -3,6 +3,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <thrust/iterator/transform_iterator.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
@@ -17,6 +19,26 @@ struct CastOp {
|
||||
}
|
||||
};
|
||||
|
||||
// Castings between complex and boolean.
|
||||
// TODO: Should make a custom complex type.
|
||||
// Casts a complex number to bool.
//
// A complex value is truthy when it is non-zero, i.e. when at least one of its
// components is non-zero. The components must be combined with `||`: with `&&`
// purely real (1+0i) and purely imaginary (0+1i) values would cast to false.
template <>
struct CastOp<cuComplex, bool> {
  static constexpr bool is_castable = true;

  // x.x holds the real part and x.y the imaginary part.
  __device__ bool operator()(cuComplex x) {
    return x.x != 0 || x.y != 0;
  }
};

// Casts a bool to a complex number.
//
// true maps to 1+0i and false maps to 0+0i, matching how a real scalar is
// promoted to complex (the imaginary part is always zero).
template <>
struct CastOp<bool, cuComplex> {
  static constexpr bool is_castable = true;

  __device__ cuComplex operator()(bool x) {
    return make_cuFloatComplex(x ? 1.0f : 0.0f, 0.0f);
  }
};
|
||||
|
||||
// Converting a complex number to real number discards the imaginary part.
|
||||
template <typename DstT>
|
||||
struct CastOp<
|
||||
@@ -45,6 +67,7 @@ struct CastOp<
|
||||
}
|
||||
};
|
||||
|
||||
// Do nothing when no casting is needed.
|
||||
template <typename SrcT, typename DstT>
|
||||
struct CastOp<
|
||||
SrcT,
|
||||
@@ -57,9 +80,53 @@ struct CastOp<
|
||||
}
|
||||
};
|
||||
|
||||
// In CUDA 11 the half types do not define conversions between some types,
|
||||
// provide fallbacks here.
|
||||
#if CUDART_VERSION < 12000
|
||||
// Selected when DstT is __half or __nv_bfloat16, SrcT is not cuComplex, and
// SrcT is not implicitly convertible to DstT. The value is widened to float
// first and the half-precision destination is then built from that float.
template <typename SrcT, typename DstT>
struct CastOp<
    SrcT,
    DstT,
    cuda::std::enable_if_t<
        (cuda::std::is_same_v<DstT, __half> ||
         cuda::std::is_same_v<DstT, __nv_bfloat16>) &&
        !cuda::std::is_same_v<SrcT, cuComplex> &&
        !cuda::std::is_convertible_v<SrcT, DstT>>> {
  static constexpr bool is_castable = true;

  __device__ DstT operator()(SrcT x) {
    // Go through float as the common intermediate representation.
    const float as_float = static_cast<float>(x);
    return DstT(as_float);
  }
};
|
||||
|
||||
// Selected when SrcT is __half or __nv_bfloat16, DstT is none of cuComplex,
// __half or __nv_bfloat16, and SrcT is not implicitly convertible to DstT.
// The half-precision value is widened to float and DstT is built from it.
template <typename SrcT, typename DstT>
struct CastOp<
    SrcT,
    DstT,
    cuda::std::enable_if_t<
        (cuda::std::is_same_v<SrcT, __half> ||
         cuda::std::is_same_v<SrcT, __nv_bfloat16>) &&
        !(cuda::std::is_same_v<DstT, cuComplex> ||
          cuda::std::is_same_v<DstT, __half> ||
          cuda::std::is_same_v<DstT, __nv_bfloat16>) &&
        !cuda::std::is_convertible_v<SrcT, DstT>>> {
  static constexpr bool is_castable = true;

  __device__ DstT operator()(SrcT x) {
    // Go through float as the common intermediate representation.
    const float as_float = static_cast<float>(x);
    return DstT(as_float);
  }
};
|
||||
#endif // CUDART_VERSION < 12000
|
||||
|
||||
// Helper to deduce the SrcT.
|
||||
// Casts x to DstT through CastOp. DstT comes first in the template parameter
// list so callers only spell out the destination type; SrcT is deduced from
// the argument.
template <typename DstT, typename SrcT>
inline __host__ __device__ auto cast_to(SrcT x) {
  using Caster = CastOp<SrcT, DstT>;
  return Caster{}(x);
}
|
||||
|
||||
// Returns an iterator that casts each value to DstT using CastOp.
|
||||
template <typename DstT, typename Iterator>
|
||||
__host__ __device__ auto make_cast_iterator(Iterator it) {
|
||||
inline __host__ __device__ auto make_cast_iterator(Iterator it) {
|
||||
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
|
||||
if constexpr (std::is_same_v<SrcT, DstT>) {
|
||||
return it;
|
||||
|
||||
@@ -99,20 +99,20 @@ struct Limits<
|
||||
return cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
|
||||
return -cuda::std::numeric_limits<T>::infinity();
|
||||
#else
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
return -cuda::std::numeric_limits<float>::infinity();
|
||||
#else
|
||||
return -cuda::std::numeric_limits<T>::infinity();
|
||||
#endif
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
|
||||
return cuda::std::numeric_limits<T>::lowest();
|
||||
#else
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
return cuda::std::numeric_limits<float>::lowest();
|
||||
#else
|
||||
return cuda::std::numeric_limits<T>::lowest();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user