mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Use cuda::std::complex in place of cuComplex (#2372)
This commit is contained in:
@@ -11,7 +11,6 @@
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
@@ -79,7 +78,7 @@ struct CTypeToCudaType<bfloat16_t> {
|
||||
|
||||
template <>
|
||||
struct CTypeToCudaType<complex64_t> {
|
||||
using type = cuComplex;
|
||||
using type = cu::complex64_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -91,10 +90,14 @@ inline constexpr bool is_floating_v =
|
||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
||||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
||||
|
||||
// Type traits for detecting complex numbers.
|
||||
template <typename T>
|
||||
inline constexpr bool is_complex_v = cuda::std::is_same_v<T, complex64_t> ||
|
||||
cuda::std::is_same_v<T, complex128_t>;
|
||||
|
||||
// Type traits for detecting complex or real floating point numbers.
|
||||
template <typename T>
|
||||
inline constexpr bool is_inexact_v =
|
||||
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
|
||||
inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
|
||||
|
||||
// Utility to copy data from vector to array in host.
|
||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||
|
||||
Reference in New Issue
Block a user