[CUDA] Use cuda::std::complex in place of cuComplex (#2372)

This commit is contained in:
Cheng
2025-07-15 16:36:13 +09:00
committed by GitHub
parent f0a0b077a0
commit cb349a291c
15 changed files with 169 additions and 460 deletions

View File

@@ -11,7 +11,6 @@
#include "mlx/array.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cuComplex.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
@@ -79,7 +78,7 @@ struct CTypeToCudaType<bfloat16_t> {
template <>
struct CTypeToCudaType<complex64_t> {
using type = cuComplex;
using type = cu::complex64_t;
};
template <typename T>
@@ -91,10 +90,14 @@ inline constexpr bool is_floating_v =
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
// Type traits for detecting complex numbers.
template <typename T>
inline constexpr bool is_complex_v = cuda::std::is_same_v<T, complex64_t> ||
cuda::std::is_same_v<T, complex128_t>;
// Type traits for detecting complex or real floating point numbers.
template <typename T>
inline constexpr bool is_inexact_v =
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
// Utility to copy data from vector to array in host.
template <int NDIM = MAX_NDIM, typename T = int32_t>