[CUDA] Use cuda::std::complex in place of cuComplex (#2372)

Authored by Cheng on 2025-07-15 16:36:13 +09:00; committed by GitHub
parent f0a0b077a0
commit cb349a291c
15 changed files with 169 additions and 460 deletions
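
Most of the diff below is a mechanical substitution: values of type cuComplex that were built with make_cuFloatComplex and read through cuCrealf/cuCimagf become cuda::std::complex<float> values with brace construction and .real()/.imag() accessors, and the arithmetic comes from the operator overloads libcu++ already provides. A minimal standalone sketch of the before/after pattern (illustrative only, not code from this commit):

// before_after.cu -- illustrative sketch, not part of the MLX sources.
#include <cuComplex.h>       // legacy CUDA complex API
#include <cuda/std/complex>  // libcu++ complex, usable in host and device code

// Old style: free helper functions and make_* constructors.
__device__ cuComplex scale_old(cuComplex z, float s) {
  return make_cuFloatComplex(cuCrealf(z) * s, cuCimagf(z) * s);
}

// New style: member accessors, brace construction, ordinary operators.
__device__ cuda::std::complex<float> scale_new(
    cuda::std::complex<float> z,
    float s) {
  return {z.real() * s, z.imag() * s};
}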

View File

@@ -3,7 +3,6 @@
 #include "mlx/backend/common/binary.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/binary_ops.cuh"
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"

View File

@@ -3,7 +3,6 @@
 #include "mlx/backend/common/binary.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/binary_ops.cuh"
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"

View File

@@ -2,7 +2,7 @@
 #pragma once
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
+#include "mlx/backend/cuda/device/complex.cuh"
 #include "mlx/backend/cuda/device/fp16_math.cuh"
 #include <cuda/atomic>
@@ -48,7 +48,7 @@ inline __device__ void atomic_add(__half* out, __half val) {
   atomicAdd(out, val);
 }
-inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
+inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
 #if __CUDA_ARCH__ < 900
   atomic_add_general(out, val);
 #else

View File

@@ -44,7 +44,7 @@ struct Remainder {
       } else {
         return x % y;
       }
-    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+    } else if constexpr (is_complex_v<T>) {
       return x % y;
     } else {
       T r = fmod(x, y);
@@ -66,14 +66,12 @@ struct Equal {
 struct NaNEqual {
   template <typename T>
   __device__ bool operator()(T x, T y) {
-    if constexpr (std::is_same_v<T, cuComplex>) {
+    if constexpr (is_complex_v<T>) {
       return x == y ||
-          (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
-           isnan(cuCimagf(y))) ||
-          (cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
-           isnan(cuCimagf(y))) ||
-          (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
-           cuCimagf(x) == cuCimagf(y));
+          (isnan(x.real()) && isnan(y.real()) && isnan(x.imag()) &&
+           isnan(y.imag())) ||
+          (x.real() == y.real() && isnan(x.imag()) && isnan(y.imag())) ||
+          (isnan(x.real()) && isnan(y.real()) && x.imag() == y.imag());
     } else {
       return x == y || (isnan(x) && isnan(y));
     }
@@ -111,17 +109,17 @@ struct LessEqual {
 struct LogAddExp {
   template <typename T>
   __device__ T operator()(T x, T y) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
-          isnan(cuCimagf(y))) {
+    if constexpr (is_complex_v<T>) {
+      if (isnan(x.real()) || isnan(x.imag()) || isnan(y.real()) ||
+          isnan(y.imag())) {
        return {
            cuda::std::numeric_limits<float>::quiet_NaN(),
            cuda::std::numeric_limits<float>::quiet_NaN()};
      }
-      auto max = cuCrealf(x) > cuCrealf(y) ? x : y;
-      auto min = cuCrealf(x) < cuCrealf(y) ? x : y;
-      auto min_real = cuCrealf(min);
-      auto max_real = cuCrealf(max);
+      auto max = x.real() > y.real() ? x : y;
+      auto min = x.real() < y.real() ? x : y;
+      auto min_real = min.real();
+      auto max_real = max.real();
       if (!isfinite(min_real) && (min_real == max_real)) {
         if (min_real < 0) {
           return min;
@@ -150,8 +148,8 @@ struct Maximum {
   __device__ T operator()(T x, T y) {
     if constexpr (cuda::std::is_integral_v<T>) {
       return max(x, y);
-    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
+    } else if constexpr (is_complex_v<T>) {
+      if (isnan(x.real()) || isnan(x.imag())) {
        return x;
      }
      return x > y ? x : y;
@@ -169,8 +167,8 @@ struct Minimum {
   __device__ T operator()(T x, T y) {
     if constexpr (cuda::std::is_integral_v<T>) {
       return min(x, y);
-    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
+    } else if constexpr (is_complex_v<T>) {
+      if (isnan(x.real()) || isnan(x.imag())) {
        return x;
      }
      return x < y ? x : y;
@@ -193,8 +191,8 @@ struct Multiply {
 struct NotEqual {
   template <typename T>
   __device__ bool operator()(T x, T y) {
-    if constexpr (std::is_same_v<T, cuComplex>) {
-      return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
+    if constexpr (is_complex_v<T>) {
+      return x.real() != y.real() || x.imag() != y.imag();
     } else {
       return x != y;
     }
@@ -214,19 +212,8 @@ struct Power {
         base *= base;
       }
       return res;
-    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      if (base.y == 0 && base.x == 0) {
-        if (isnan(exp.x) || isnan(exp.y)) {
-          auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
-          return make_cuFloatComplex(nan, nan);
-        }
-        return make_cuFloatComplex(0.0, 0.0);
-      }
-      auto x_theta = atan2f(base.y, base.x);
-      auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
-      auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
-      auto phase = exp.y * x_ln_r + exp.x * x_theta;
-      return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
+    } else if constexpr (is_complex_v<T>) {
+      return pow(base, exp);
     } else {
       return powf(base, exp);
     }
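
The removed Power branch computed the complex power by hand in polar form; the replacement defers to the pow overload that ships with cuda::std::complex, which is typically implemented via the same exp/log identity, up to edge-case handling for zero and NaN inputs. A rough standalone sketch of that equivalence, assuming the libcu++ exp/log/pow overloads behave like their std counterparts:

#include <cuda/std/complex>

using c64 = cuda::std::complex<float>;

// The identity the old hand-written branch implemented: b^e = exp(e * log(b)),
// with 0^e special-cased to 0 (NaN propagation handled separately there).
__host__ __device__ c64 pow_via_identity(c64 base, c64 exponent) {
  if (base == c64{0.0f, 0.0f}) {
    return c64{0.0f, 0.0f};
  }
  return cuda::std::exp(exponent * cuda::std::log(base));
}

// The replacement simply calls the library overload.
__host__ __device__ c64 pow_via_library(c64 base, c64 exponent) {
  return cuda::std::pow(base, exponent);
}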

View File

@@ -2,7 +2,8 @@
 #pragma once
-#include <cuComplex.h>
+#include "mlx/backend/cuda/device/complex.cuh"
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <thrust/iterator/transform_iterator.h>
@@ -20,50 +21,43 @@ struct CastOp {
 };
 // Castings between complex and boolean.
-// TODO: Should make a custom complex type.
-template <>
-struct CastOp<cuComplex, bool> {
+template <typename T>
+struct CastOp<complex_t<T>, bool> {
   static constexpr bool is_castable = true;
-  __device__ bool operator()(cuComplex x) {
-    return x.x != 0 && x.y != 0;
+  __device__ bool operator()(complex_t<T> x) {
+    return x.real() != 0 && x.imag() != 0;
   }
 };
-template <>
-struct CastOp<bool, cuComplex> {
+template <typename T>
+struct CastOp<bool, complex_t<T>> {
   static constexpr bool is_castable = true;
-  __device__ cuComplex operator()(bool x) {
-    return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
+  __device__ complex_t<T> operator()(bool x) {
+    return x ? complex_t<T>{1, 1} : complex_t<T>{0, 0};
   }
 };
 // Converting a complex number to real number discards the imaginary part.
-template <typename DstT>
-struct CastOp<
-    cuComplex,
-    DstT,
-    cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
-  static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;
-  __device__ DstT operator()(cuComplex x) {
-    static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
-    return static_cast<DstT>(cuCrealf(x));
+template <typename T, typename DstT>
+struct CastOp<complex_t<T>, DstT, cuda::std::enable_if_t<!is_complex_v<DstT>>> {
+  static constexpr bool is_castable = cuda::std::is_convertible_v<T, DstT>;
+  __device__ DstT operator()(complex_t<T> x) {
+    static_assert(!is_complex_v<DstT>);
+    return static_cast<DstT>(x.real());
   }
 };
 // Allow converting a real number to complex number.
-template <typename SrcT>
-struct CastOp<
-    SrcT,
-    cuComplex,
-    cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
-  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;
-  __device__ cuComplex operator()(SrcT x) {
-    static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
-    return cuComplex{static_cast<float>(x), 0};
+template <typename SrcT, typename T>
+struct CastOp<SrcT, complex_t<T>, cuda::std::enable_if_t<!is_complex_v<SrcT>>> {
+  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, T>;
+  __device__ complex_t<T> operator()(SrcT x) {
+    static_assert(!is_complex_v<SrcT>);
+    return complex_t<T>{static_cast<T>(x), 0};
   }
 };
@@ -88,8 +82,7 @@ struct CastOp<
     SrcT,
     DstT,
     cuda::std::enable_if_t<
-        !cuda::std::is_convertible_v<SrcT, DstT> &&
-        !cuda::std::is_same_v<SrcT, cuComplex> &&
+        !cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
         (cuda::std::is_same_v<DstT, __half> ||
          cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
   static constexpr bool is_castable = true;
@@ -104,8 +97,7 @@ struct CastOp<
     SrcT,
     DstT,
     cuda::std::enable_if_t<
-        !cuda::std::is_convertible_v<SrcT, DstT> &&
-        !cuda::std::is_same_v<DstT, cuComplex> &&
+        !cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
         !cuda::std::is_same_v<DstT, __half> &&
         !cuda::std::is_same_v<DstT, __nv_bfloat16> &&
         (cuda::std::is_same_v<SrcT, __half> ||
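
The templated specializations above encode the usual conversion rules: complex-to-real keeps only the real part, and real-to-complex fills in a zero imaginary part. Written out as plain functions on cuda::std::complex (a sketch mirroring those rules, not the actual CastOp machinery):

#include <cuda/std/complex>

template <typename T>
using complex_t = cuda::std::complex<T>;

// complex -> real discards the imaginary part, like CastOp<complex_t<T>, DstT>.
template <typename DstT, typename T>
__host__ __device__ DstT complex_to_real(complex_t<T> x) {
  return static_cast<DstT>(x.real());
}

// real -> complex sets a zero imaginary part, like CastOp<SrcT, complex_t<T>>.
template <typename T, typename SrcT>
__host__ __device__ complex_t<T> real_to_complex(SrcT x) {
  return complex_t<T>{static_cast<T>(x), 0};
}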

View File

@@ -0,0 +1,61 @@
// Copyright © 2025 Apple Inc.
#pragma once
// Make multiplication and division faster.
#define LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS
#include <cuda/std/complex>
#include <cuda/std/type_traits>
namespace mlx::core::cu {
// TODO: Consider using a faster implementation as cuda::std::complex has to
// conform to C++ standard.
template <typename T>
using complex_t = cuda::std::complex<T>;
using complex64_t = complex_t<float>;
using complex128_t = complex_t<double>;
template <typename T>
struct is_complex : cuda::std::false_type {};
template <typename T>
struct is_complex<cuda::std::complex<T>> : cuda::std::true_type {};
template <typename T>
inline constexpr bool is_complex_v = is_complex<T>::value;
// cuda::std::complex is missing some operators.
template <typename T>
inline __host__ __device__ complex_t<T> operator%(
complex_t<T> a,
complex_t<T> b) {
T r = a.real() - floor(a.real() / b.real()) * b.real();
T i = a.imag() - floor(a.imag() / b.imag()) * b.imag();
return complex_t<T>{r, i};
}
template <typename T>
inline __host__ __device__ bool operator<(complex_t<T> a, complex_t<T> b) {
return (a.real() * a.real() + a.imag() * a.imag()) <
(b.real() * b.real() + b.imag() * b.imag());
}
template <typename T>
inline __host__ __device__ bool operator>(complex_t<T> a, complex_t<T> b) {
return b < a;
}
template <typename T>
inline __host__ __device__ bool operator<=(complex_t<T> a, complex_t<T> b) {
return !(a > b);
}
template <typename T>
inline __host__ __device__ bool operator>=(complex_t<T> a, complex_t<T> b) {
return !(a < b);
}
} // namespace mlx::core::cu
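
Two conventions in this new header are worth spelling out: is_complex_v is a trait matching any cuda::std::complex<T> instantiation, and the added comparison operators order complex numbers by squared magnitude, which induces the same ordering as the magnitude itself while avoiding a sqrt. A small standalone sketch of the trait pattern and the kind of compile-time dispatch this commit builds on (the header above is internal to MLX, so the sketch restates the trait):

#include <cuda/std/complex>
#include <cuda/std/type_traits>

// Same trait pattern as complex.cuh, restated for a self-contained example.
template <typename T>
struct is_complex : cuda::std::false_type {};
template <typename T>
struct is_complex<cuda::std::complex<T>> : cuda::std::true_type {};
template <typename T>
inline constexpr bool is_complex_v = is_complex<T>::value;

// The kind of compile-time dispatch the operator functors in this commit use.
template <typename T>
__host__ __device__ T zero_value() {
  if constexpr (is_complex_v<T>) {
    return T{0, 0};  // complex zero
  } else {
    return T{0};
  }
}

static_assert(is_complex_v<cuda::std::complex<float>>);
static_assert(!is_complex_v<float>);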

View File

@@ -1,240 +0,0 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2017-2024 The Simons Foundation, Inc.
//
// FINUFFT is licensed under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance with the
// License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h
#pragma once
#include <cuComplex.h>
// This header provides some helper functions for cuComplex types.
// It mainly wraps existing CUDA implementations to provide operator overloads
// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are
// all provided by CUDA
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCadd(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCsub(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCmul(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCdiv(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
double r = cuCreal(a) - (floorf(cuCreal(a) / cuCreal(b)) * cuCreal(b));
double i = cuCimag(a) - (floorf(cuCimag(a) / cuCimag(b)) * cuCimag(b));
return make_cuDoubleComplex(r, i);
}
__forceinline__ __host__ __device__ bool operator==(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
}
__forceinline__ __host__ __device__ bool operator!=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return !(a == b);
}
__forceinline__ __host__ __device__ bool operator>(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
return mag_a > mag_b;
}
__forceinline__ __host__ __device__ bool operator>=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return a > b || a == b;
}
__forceinline__ __host__ __device__ bool operator<(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return b > a;
}
__forceinline__ __host__ __device__ bool operator<=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return b > a || a == b;
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(double a, const cuDoubleComplex& b) {
double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
return make_cuDoubleComplex(
(a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCaddf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCsubf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCmulf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCdivf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
return make_cuFloatComplex(r, i);
}
__forceinline__ __host__ __device__ bool operator==(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
}
__forceinline__ __host__ __device__ bool operator!=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return !(a == b);
}
__forceinline__ __host__ __device__ bool operator>(
const cuFloatComplex& a,
const cuFloatComplex& b) {
float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
return mag_a > mag_b;
}
__forceinline__ __host__ __device__ bool operator>=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return a > b || a == b;
}
__forceinline__ __host__ __device__ bool operator<(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return b > a;
}
__forceinline__ __host__ __device__ bool operator<=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return b > a || a == b;
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(float a, const cuFloatComplex& b) {
float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
return make_cuFloatComplex(
(a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
}

View File

@@ -2,12 +2,10 @@
 #pragma once
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
 #include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/device/utils.cuh"
 #include <math_constants.h>
-#include <cuda/std/complex>
 namespace mlx::core::cu {
@@ -16,8 +14,6 @@ struct Abs {
   __device__ T operator()(T x) {
     if constexpr (cuda::std::is_unsigned_v<T>) {
       return x;
-    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
     } else {
       return abs(x);
     }
@@ -29,8 +25,6 @@ struct ArcCos {
   __device__ T operator()(T x) {
     return acos(x);
   }
-  __device__ cuComplex operator()(cuComplex x);
 };
 struct ArcCosh {
@@ -45,8 +39,6 @@ struct ArcSin {
   __device__ T operator()(T x) {
     return asin(x);
   }
-  __device__ cuComplex operator()(cuComplex x);
 };
 struct ArcSinh {
@@ -61,8 +53,6 @@ struct ArcTan {
   __device__ T operator()(T x) {
     return atan(x);
   }
-  __device__ cuComplex operator()(cuComplex x);
 };
 struct ArcTanh {
@@ -84,6 +74,8 @@ struct Ceil {
   __device__ T operator()(T x) {
     if constexpr (cuda::std::is_integral_v<T>) {
       return x;
+    } else if constexpr (is_complex_v<T>) {
+      return T{ceil(x.real()), ceil(x.imag())};
     } else {
       return ceil(x);
     }
@@ -91,34 +83,23 @@ struct Ceil {
 };
 struct Conjugate {
-  __device__ cuComplex operator()(cuComplex x) {
-    return {cuCrealf(x), -cuCimagf(x)};
+  template <typename T>
+  __device__ complex_t<T> operator()(complex_t<T> x) {
+    return conj(x);
   }
 };
 struct Cos {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return {
-          cos(cuCrealf(x)) * cosh(cuCimagf(x)),
-          -sin(cuCrealf(x)) * sinh(cuCimagf(x))};
-    } else {
-      return cos(x);
-    }
+    return cos(x);
   }
 };
 struct Cosh {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return {
-          cosh(cuCrealf(x)) * cos(cuCimagf(x)),
-          sinh(cuCrealf(x)) * sin(cuCimagf(x))};
-    } else {
-      return cosh(x);
-    }
+    return cosh(x);
   }
 };
@@ -151,12 +132,7 @@ struct ErfInv {
 struct Exp {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      auto r = exp(cuda::std::complex<float>{cuCrealf(x), cuCimagf(x)});
-      return cuComplex{r.real(), r.imag()};
-    } else {
-      return exp(x);
-    }
+    return exp(x);
   }
 };
@@ -178,6 +154,8 @@ struct Floor {
   __device__ T operator()(T x) {
     if constexpr (cuda::std::is_integral_v<T>) {
       return x;
+    } else if constexpr (is_complex_v<T>) {
+      return T{floor(x.real()), floor(x.imag())};
     } else {
       return floor(x);
     }
@@ -185,30 +163,25 @@ struct Floor {
 };
 struct Imag {
-  __device__ float operator()(cuComplex x) {
-    return cuCimagf(x);
+  template <typename T>
+  __device__ auto operator()(complex_t<T> x) {
+    return x.imag();
   }
 };
 struct Log {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      auto r = log(cuCrealf(Abs{}(x)));
-      auto i = atan2f(cuCimagf(x), cuCrealf(x));
-      return {r, i};
-    } else {
-      return log(x);
-    }
+    return log(x);
   }
 };
 struct Log2 {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+    if constexpr (is_complex_v<T>) {
       auto y = Log{}(x);
-      return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
+      return {y.real() / CUDART_LN2_F, y.imag() / CUDART_LN2_F};
     } else {
       return log2(x);
     }
@@ -218,23 +191,17 @@ struct Log2 {
 struct Log10 {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      auto y = Log{}(x);
-      return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
-      return y;
-    } else {
-      return log10(x);
-    }
+    return log10(x);
   }
 };
 struct Log1p {
   template <typename T>
   __device__ T operator()(T z) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      float x = cuCrealf(z);
-      float y = cuCimagf(z);
-      float zabs = cuCrealf(Abs{}(z));
+    if constexpr (is_complex_v<T>) {
+      float x = z.real();
+      float y = z.imag();
+      float zabs = Abs{}(z).real();
       float theta = atan2f(y, x + 1);
       if (zabs < 0.5f) {
         float r = x * (2 + x) + y * y;
@@ -261,8 +228,8 @@ struct LogicalNot {
 struct Negative {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return 0 - x;
+    if constexpr (is_complex_v<T>) {
+      return T{0, 0} - x;
     } else {
       return -x;
     }
@@ -270,16 +237,17 @@ struct Negative {
 };
 struct Real {
-  __device__ float operator()(cuComplex x) {
-    return cuCrealf(x);
+  template <typename T>
+  __device__ auto operator()(complex_t<T> x) {
+    return x.real();
   }
 };
 struct Round {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return {rint(cuCrealf(x)), rint(cuCimagf(x))};
+    if constexpr (is_complex_v<T>) {
+      return {rint(x.real()), rint(x.imag())};
     } else {
       return rint(x);
     }
@@ -299,8 +267,8 @@ struct Sign {
   __device__ T operator()(T x) {
     if constexpr (cuda::std::is_unsigned_v<T>) {
       return x != 0;
-    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
+    } else if constexpr (is_complex_v<T>) {
+      if (x.real() == 0 && x.imag() == 0) {
        return x;
      } else {
        return x / Abs()(x);
@@ -316,26 +284,14 @@ struct Sign {
 struct Sin {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return {
-          sin(cuCrealf(x)) * cosh(cuCimagf(x)),
-          cos(cuCrealf(x)) * sinh(cuCimagf(x))};
-    } else {
-      return sin(x);
-    }
+    return sin(x);
   }
 };
 struct Sinh {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return {
-          sinh(cuCrealf(x)) * cos(cuCimagf(x)),
-          cosh(cuCrealf(x)) * sin(cuCimagf(x))};
-    } else {
-      return sinh(x);
-    }
+    return sinh(x);
   }
 };
@@ -351,77 +307,31 @@ struct Sqrt {
   __device__ T operator()(T x) {
     return sqrt(x);
   }
-  __device__ cuComplex operator()(cuComplex x) {
-    auto xr = cuCrealf(x);
-    auto xi = cuCimagf(x);
-    if (xr == 0.0f && xi == 0.0f) {
-      return {0.0f, 0.0f};
-    }
-    auto r = cuCrealf(Abs{}(x));
-    auto a = sqrt((r + xr) / 2.0f);
-    auto b_abs = sqrt((r - xr) / 2.0f);
-    auto b = copysign(b_abs, xi);
-    return {a, b};
-  }
 };
 struct Rsqrt {
   template <typename T>
   __device__ T operator()(T x) {
-    return rsqrt(x);
-  }
-  __device__ cuComplex operator()(cuComplex x) {
-    return 1.0f / Sqrt{}(x);
+    if constexpr (is_complex_v<T>) {
+      return 1.0f / Sqrt{}(x);
+    } else {
+      return rsqrt(x);
+    }
   }
 };
 struct Tan {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      float tan_a = tan(cuCrealf(x));
-      float tanh_b = tanh(cuCimagf(x));
-      float t1 = tan_a * tanh_b;
-      float denom = 1. + t1 * t1;
-      return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
-    } else {
-      return tan(x);
-    }
+    return tan(x);
  }
 };
 struct Tanh {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      float tanh_a = tanh(cuCrealf(x));
-      float tan_b = tan(cuCimagf(x));
-      float t1 = tanh_a * tan_b;
-      float denom = 1. + t1 * t1;
-      return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
-    } else {
-      return tanh(x);
-    }
+    return tanh(x);
   }
 };
-inline __device__ cuComplex ArcCos::operator()(cuComplex x) {
-  auto i = cuComplex{0.0, 1.0};
-  auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
-  return {cuCimagf(y), -cuCrealf(y)};
-};
-inline __device__ cuComplex ArcSin::operator()(cuComplex x) {
-  auto i = cuComplex{0.0f, 1.0f};
-  auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
-  return {cuCimagf(y), -cuCrealf(y)};
-};
-inline __device__ cuComplex ArcTan::operator()(cuComplex x) {
-  auto i = cuComplex{0.0f, 1.0f};
-  auto ix = i * x;
-  return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
-};
 } // namespace mlx::core::cu
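
Most of the hand-written complex branches above (Cos, Cosh, Exp, Log, Log10, Sin, Sinh, Sqrt, Tan, Tanh, Conjugate, and the out-of-line ArcCos/ArcSin/ArcTan definitions) disappear because cuda::std::complex ships the corresponding math overloads; an unqualified call such as exp(x) picks them up via argument-dependent lookup, so the same generic functor body now serves real and complex types. A small sketch of that pattern, assuming only the libcu++ overloads:

#include <cuda/std/complex>

// Generic functor in the style of the ones above: the unqualified call
// resolves to the CUDA math function for float and to cuda::std::exp for
// cuda::std::complex<float> (found by argument-dependent lookup).
struct ExpOp {
  template <typename T>
  __device__ T operator()(T x) {
    return exp(x);
  }
};

__global__ void exp_all(cuda::std::complex<float>* zs, float* xs, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    zs[i] = ExpOp{}(zs[i]);  // complex overload
    xs[i] = ExpOp{}(xs[i]);  // real overload
  }
}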

View File

@@ -8,9 +8,9 @@
 #pragma once
+#include "mlx/backend/cuda/device/complex.cuh"
 #include "mlx/backend/cuda/device/config.h"
-#include <cuComplex.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <cuda/std/array>
@@ -127,13 +127,13 @@ struct Limits<bool> {
   }
 };
-template <>
-struct Limits<cuComplex> {
-  static constexpr __host__ __device__ cuComplex max() {
-    return {Limits<float>::max(), Limits<float>::max()};
+template <typename T>
+struct Limits<complex_t<T>> {
+  static constexpr __host__ __device__ complex_t<T> max() {
+    return {Limits<T>::max(), Limits<T>::max()};
   }
-  static constexpr __host__ __device__ cuComplex min() {
-    return {Limits<float>::min(), Limits<float>::min()};
+  static constexpr __host__ __device__ complex_t<T> min() {
+    return {Limits<T>::min(), Limits<T>::min()};
   }
 };

View File

@@ -173,7 +173,7 @@ constexpr const char* g_include_names[] = {
     INCLUDE_PREFIX "binary_ops.cuh",
     INCLUDE_PREFIX "cast_op.cuh",
     INCLUDE_PREFIX "config.h",
-    INCLUDE_PREFIX "cucomplex_math.cuh",
+    INCLUDE_PREFIX "complex.cuh",
     INCLUDE_PREFIX "fp16_math.cuh",
     INCLUDE_PREFIX "indexing.cuh",
     INCLUDE_PREFIX "scatter_ops.cuh",
@@ -189,7 +189,7 @@ constexpr const char* g_headers[] = {
     jit_source_binary_ops,
     jit_source_cast_op,
     jit_source_config,
-    jit_source_cucomplex_math,
+    jit_source_complex,
     jit_source_fp16_math,
     jit_source_indexing,
     jit_source_scatter_ops,

View File

@@ -11,7 +11,6 @@
 #include "mlx/array.h"
 #include "mlx/backend/cuda/device/utils.cuh"
-#include <cuComplex.h>
 #include <cuda.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
@@ -79,7 +78,7 @@ struct CTypeToCudaType<bfloat16_t> {
 template <>
 struct CTypeToCudaType<complex64_t> {
-  using type = cuComplex;
+  using type = cu::complex64_t;
 };
 template <typename T>
@@ -91,10 +90,14 @@ inline constexpr bool is_floating_v =
     cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
     cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
+// Type traits for detecting complex numbers.
+template <typename T>
+inline constexpr bool is_complex_v = cuda::std::is_same_v<T, complex64_t> ||
+    cuda::std::is_same_v<T, complex128_t>;
 // Type traits for detecting complex or real floating point numbers.
 template <typename T>
-inline constexpr bool is_inexact_v =
-    is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
+inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
 // Utility to copy data from vector to array in host.
 template <int NDIM = MAX_NDIM, typename T = int32_t>
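
With is_complex_v covering both complex64_t and complex128_t, is_inexact_v collapses to a single disjunction. A sketch of how such traits typically gate a kernel (illustrative only, trimmed to built-in element types, not the MLX dispatch machinery):

#include <cuda/std/complex>
#include <cuda/std/type_traits>

using complex64_t = cuda::std::complex<float>;
using complex128_t = cuda::std::complex<double>;

template <typename T>
inline constexpr bool is_floating_v =
    cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double>;

template <typename T>
inline constexpr bool is_complex_v = cuda::std::is_same_v<T, complex64_t> ||
    cuda::std::is_same_v<T, complex128_t>;

template <typename T>
inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;

// Only meaningful for inexact (floating or complex) element types.
template <typename T>
__global__ void nan_to_zero(T* data, int n) {
  static_assert(is_inexact_v<T>, "expects a floating or complex type");
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    if constexpr (is_complex_v<T>) {
      if (isnan(data[i].real()) || isnan(data[i].imag())) {
        data[i] = T{0, 0};
      }
    } else {
      if (isnan(data[i])) {
        data[i] = T{0};
      }
    }
  }
}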

View File

@@ -3,7 +3,6 @@
 #include <type_traits>
 #include "mlx/backend/common/reduce.h"
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/backend/cuda/reduce/reduce_ops.cuh"
 #include "mlx/dtype_utils.h"

View File

@@ -151,7 +151,7 @@ struct ReduceInit<Or, T> {
 template <typename T>
 struct ReduceInit<Sum, T> {
   static constexpr __host__ __device__ auto value() {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+    if constexpr (is_complex_v<T>) {
       return T{0, 0};
     } else {
       return cast_to<typename ReduceResult<Sum, T>::type>(0);
@@ -162,7 +162,7 @@ struct ReduceInit<Sum, T> {
 template <typename T>
 struct ReduceInit<Prod, T> {
   static constexpr __host__ __device__ auto value() {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+    if constexpr (is_complex_v<T>) {
       return T{1, 0};
     } else {
       return cast_to<typename ReduceResult<Prod, T>::type>(1);

View File

@@ -2,7 +2,6 @@
 #include "mlx/backend/common/unary.h"
 #include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
 #include "mlx/backend/cuda/device/unary_ops.cuh"
 #include "mlx/backend/cuda/iterators/general_iterator.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
@@ -71,10 +70,10 @@ constexpr bool supports_unary_op() {
         !std::is_same_v<In, bool>;
   }
   if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
-    return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
+    return std::is_same_v<In, Out> && !mlx::core::is_complex_v<In>;
   }
   if (std::is_same_v<Op, Conjugate>) {
-    return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
+    return std::is_same_v<In, Out> && mlx::core::is_complex_v<In>;
   }
   if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
       std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
@@ -88,7 +87,7 @@ constexpr bool supports_unary_op() {
     return std::is_same_v<In, Out> && is_inexact_v<In>;
   }
   if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
-    return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
+    return mlx::core::is_complex_v<In> && std::is_same_v<Out, float>;
   }
   if (std::is_same_v<Op, LogicalNot>) {
     return std::is_same_v<In, Out> && std::is_same_v<In, bool>;

View File

@@ -61,7 +61,7 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
     case float64:
       return "double";
     case complex64:
-      return "cuComplex";
+      return "complex64_t";
     default:
       return "unknown";
   }
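
The string returned for complex64 is pasted into JIT-generated kernel sources, so it has to name a type that the headers shipped to NVRTC (now complex.cuh instead of cucomplex_math.cuh, per the jit_module.cpp change above) actually declare. Conceptually, a generated translation unit now looks something like this (a simplified sketch; the real MLX codegen differs):

#include <cuda/std/complex>

namespace mlx::core::cu {
template <typename T>
using complex_t = cuda::std::complex<T>;
using complex64_t = complex_t<float>;  // the name dtype_to_cuda_type(complex64) emits
}  // namespace mlx::core::cu

using namespace mlx::core::cu;

__global__ void jit_binary_add(
    const complex64_t* a,
    const complex64_t* b,
    complex64_t* out,
    int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = a[i] + b[i];  // operator+ comes with cuda::std::complex
  }
}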