From cb349a291c4417ef29545f1f075f558d591de5f7 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Tue, 15 Jul 2025 16:36:13 +0900
Subject: [PATCH] [CUDA] Use cuda::std::complex in place of cuComplex (#2372)

---
 mlx/backend/cuda/binary.cu                 |   1 -
 mlx/backend/cuda/binary_two.cu             |   1 -
 mlx/backend/cuda/device/atomic_ops.cuh     |   4 +-
 mlx/backend/cuda/device/binary_ops.cuh     |  55 ++---
 mlx/backend/cuda/device/cast_op.cuh        |  56 +++--
 mlx/backend/cuda/device/complex.cuh        |  61 ++++++
 mlx/backend/cuda/device/cucomplex_math.cuh | 240 ---------------------
 mlx/backend/cuda/device/unary_ops.cuh      | 168 ++++-----------
 mlx/backend/cuda/device/utils.cuh          |  14 +-
 mlx/backend/cuda/jit_module.cpp            |   4 +-
 mlx/backend/cuda/kernel_utils.cuh          |  11 +-
 mlx/backend/cuda/reduce/reduce.cuh         |   1 -
 mlx/backend/cuda/reduce/reduce_ops.cuh     |   4 +-
 mlx/backend/cuda/unary.cu                  |   7 +-
 mlx/backend/cuda/utils.cpp                 |   2 +-
 15 files changed, 169 insertions(+), 460 deletions(-)
 create mode 100644 mlx/backend/cuda/device/complex.cuh
 delete mode 100644 mlx/backend/cuda/device/cucomplex_math.cuh

diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu
index c8586e638..3eade024d 100644
--- a/mlx/backend/cuda/binary.cu
+++ b/mlx/backend/cuda/binary.cu
@@ -3,7 +3,6 @@
 #include "mlx/backend/common/binary.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/binary_ops.cuh"
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu
index 0918c579f..3ac8a9516 100644
--- a/mlx/backend/cuda/binary_two.cu
+++ b/mlx/backend/cuda/binary_two.cu
@@ -3,7 +3,6 @@
 #include "mlx/backend/common/binary.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/binary_ops.cuh"
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
diff --git a/mlx/backend/cuda/device/atomic_ops.cuh b/mlx/backend/cuda/device/atomic_ops.cuh
index e0d3c3eac..5df246c0e 100644
--- a/mlx/backend/cuda/device/atomic_ops.cuh
+++ b/mlx/backend/cuda/device/atomic_ops.cuh
@@ -2,7 +2,7 @@
 
 #pragma once
 
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
+#include "mlx/backend/cuda/device/complex.cuh"
 #include "mlx/backend/cuda/device/fp16_math.cuh"
 
 #include <cuda/atomic>
@@ -48,7 +48,7 @@ inline __device__ void atomic_add(__half* out, __half val) {
   atomicAdd(out, val);
 }
 
-inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
+inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
 #if __CUDA_ARCH__ < 900
   atomic_add_general(out, val);
 #else
diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh
index 644786a92..575aced14 100644
--- a/mlx/backend/cuda/device/binary_ops.cuh
+++ b/mlx/backend/cuda/device/binary_ops.cuh
@@ -44,7 +44,7 @@ struct Remainder {
       } else {
         return x % y;
       }
-    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+    } else if constexpr (is_complex_v<T>) {
       return x % y;
     } else {
       T r = fmod(x, y);
@@ -66,14 +66,12 @@ struct Equal {
 struct NaNEqual {
   template <typename T>
   __device__ bool operator()(T x, T y) {
-    if constexpr (std::is_same_v<T, cuComplex>) {
+    if constexpr (is_complex_v<T>) {
       return x == y ||
-          (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
-           isnan(cuCimagf(y))) ||
-          (cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
-           isnan(cuCimagf(y))) ||
-          (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
-           cuCimagf(x) == cuCimagf(y));
+          (isnan(x.real()) && isnan(y.real()) && isnan(x.imag()) &&
+           isnan(y.imag())) ||
+          (x.real() == y.real() && isnan(x.imag()) && isnan(y.imag())) ||
+          (isnan(x.real()) && isnan(y.real()) && x.imag() == y.imag());
     } else {
       return x == y || (isnan(x) && isnan(y));
     }
@@ -111,17 +109,17 @@ struct LessEqual {
 struct LogAddExp {
   template <typename T>
   __device__ T operator()(T x, T y) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
-          isnan(cuCimagf(y))) {
+    if constexpr (is_complex_v<T>) {
+      if (isnan(x.real()) || isnan(x.imag()) || isnan(y.real()) ||
+          isnan(y.imag())) {
         return {
             cuda::std::numeric_limits<float>::quiet_NaN(),
             cuda::std::numeric_limits<float>::quiet_NaN()};
       }
-      auto max = cuCrealf(x) > cuCrealf(y) ? x : y;
-      auto min = cuCrealf(x) < cuCrealf(y) ? x : y;
-      auto min_real = cuCrealf(min);
-      auto max_real = cuCrealf(max);
+      auto max = x.real() > y.real() ? x : y;
+      auto min = x.real() < y.real() ? x : y;
+      auto min_real = min.real();
+      auto max_real = max.real();
       if (!isfinite(min_real) && (min_real == max_real)) {
         if (min_real < 0) {
           return min;
@@ -150,8 +148,8 @@ struct Maximum {
   __device__ T operator()(T x, T y) {
     if constexpr (cuda::std::is_integral_v<T>) {
       return max(x, y);
-    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
+    } else if constexpr (is_complex_v<T>) {
+      if (isnan(x.real()) || isnan(x.imag())) {
         return x;
       }
       return x > y ? x : y;
@@ -169,8 +167,8 @@ struct Minimum {
   __device__ T operator()(T x, T y) {
     if constexpr (cuda::std::is_integral_v<T>) {
       return min(x, y);
-    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
+    } else if constexpr (is_complex_v<T>) {
+      if (isnan(x.real()) || isnan(x.imag())) {
         return x;
       }
       return x < y ? x : y;
@@ -193,8 +191,8 @@ struct Multiply {
 struct NotEqual {
   template <typename T>
   __device__ bool operator()(T x, T y) {
-    if constexpr (std::is_same_v<T, cuComplex>) {
-      return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
+    if constexpr (is_complex_v<T>) {
+      return x.real() != y.real() || x.imag() != y.imag();
     } else {
       return x != y;
     }
@@ -214,19 +212,8 @@ struct Power {
         base *= base;
       }
       return res;
-    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      if (base.y == 0 && base.x == 0) {
-        if (isnan(exp.x) || isnan(exp.y)) {
-          auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
-          return make_cuFloatComplex(nan, nan);
-        }
-        return make_cuFloatComplex(0.0, 0.0);
-      }
-      auto x_theta = atan2f(base.y, base.x);
-      auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
-      auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
-      auto phase = exp.y * x_ln_r + exp.x * x_theta;
-      return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
+    } else if constexpr (is_complex_v<T>) {
+      return pow(base, exp);
     } else {
       return powf(base, exp);
     }
diff --git a/mlx/backend/cuda/device/cast_op.cuh b/mlx/backend/cuda/device/cast_op.cuh
index 8da19ddf8..e10fde6dc 100644
--- a/mlx/backend/cuda/device/cast_op.cuh
+++ b/mlx/backend/cuda/device/cast_op.cuh
@@ -2,7 +2,8 @@
 
 #pragma once
 
-#include <cuComplex.h>
+#include "mlx/backend/cuda/device/complex.cuh"
+
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <thrust/iterator/transform_iterator.h>
@@ -20,50 +21,43 @@ struct CastOp {
 };
 
 // Castings between complex and boolean.
-// TODO: Should make a custom complex type.
-template <>
-struct CastOp<cuComplex, bool> {
+template <typename T>
+struct CastOp<complex_t<T>, bool> {
   static constexpr bool is_castable = true;
 
-  __device__ bool operator()(cuComplex x) {
-    return x.x != 0 && x.y != 0;
+  __device__ bool operator()(complex_t<T> x) {
+    return x.real() != 0 && x.imag() != 0;
   }
 };
 
-template <>
-struct CastOp<bool, cuComplex> {
+template <typename T>
+struct CastOp<bool, complex_t<T>> {
   static constexpr bool is_castable = true;
 
-  __device__ cuComplex operator()(bool x) {
-    return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
+  __device__ complex_t<T> operator()(bool x) {
+    return x ? complex_t<T>{1, 1} : complex_t<T>{0, 0};
   }
 };
 
 // Converting a complex number to real number discards the imaginary part.
-template <typename DstT>
-struct CastOp<
-    cuComplex,
-    DstT,
-    cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
-  static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;
+template <typename T, typename DstT>
+struct CastOp<complex_t<T>, DstT, cuda::std::enable_if_t<!is_complex_v<DstT>>> {
+  static constexpr bool is_castable = cuda::std::is_convertible_v<T, DstT>;
 
-  __device__ DstT operator()(cuComplex x) {
-    static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
-    return static_cast<DstT>(cuCrealf(x));
+  __device__ DstT operator()(complex_t<T> x) {
+    static_assert(!is_complex_v<DstT>);
+    return static_cast<DstT>(x.real());
   }
 };
 
 // Allow converting a real number to complex number.
-template <typename SrcT>
-struct CastOp<
-    SrcT,
-    cuComplex,
-    cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
-  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;
+template <typename SrcT, typename T>
+struct CastOp<SrcT, complex_t<T>, cuda::std::enable_if_t<!is_complex_v<SrcT>>> {
+  static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, T>;
 
-  __device__ cuComplex operator()(SrcT x) {
-    static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
-    return cuComplex{static_cast<float>(x), 0};
+  __device__ complex_t<T> operator()(SrcT x) {
+    static_assert(!is_complex_v<SrcT>);
+    return complex_t<T>{static_cast<T>(x), 0};
   }
 };
 
@@ -88,8 +82,7 @@ struct CastOp<
     SrcT,
     DstT,
     cuda::std::enable_if_t<
-        !cuda::std::is_convertible_v<SrcT, DstT> &&
-        !cuda::std::is_same_v<SrcT, cuComplex> &&
+        !cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
         (cuda::std::is_same_v<DstT, __half> ||
          cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
   static constexpr bool is_castable = true;
@@ -104,8 +97,7 @@ struct CastOp<
     SrcT,
     DstT,
     cuda::std::enable_if_t<
-        !cuda::std::is_convertible_v<SrcT, DstT> &&
-        !cuda::std::is_same_v<SrcT, cuComplex> &&
+        !cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
         !cuda::std::is_same_v<DstT, __half> &&
         !cuda::std::is_same_v<DstT, __nv_bfloat16> &&
         (cuda::std::is_same_v<SrcT, __half> ||
diff --git a/mlx/backend/cuda/device/complex.cuh b/mlx/backend/cuda/device/complex.cuh
new file mode 100644
index 000000000..8dfd23b46
--- /dev/null
+++ b/mlx/backend/cuda/device/complex.cuh
@@ -0,0 +1,61 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+// Make multiplication and division faster.
+#define LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS
+
+#include <cuda/std/complex>
+#include <cuda/std/type_traits>
+
+namespace mlx::core::cu {
+
+// TODO: Consider using a faster implementation as cuda::std::complex has to
+// conform to C++ standard.
+template <typename T>
+using complex_t = cuda::std::complex<T>;
+
+using complex64_t = complex_t<float>;
+using complex128_t = complex_t<double>;
+
+template <typename T>
+struct is_complex : cuda::std::false_type {};
+
+template <typename T>
+struct is_complex<cuda::std::complex<T>> : cuda::std::true_type {};
+
+template <typename T>
+inline constexpr bool is_complex_v = is_complex<T>::value;
+
+// cuda::std::complex is missing some operators.
+template <typename T>
+inline __host__ __device__ complex_t<T> operator%(
+    complex_t<T> a,
+    complex_t<T> b) {
+  T r = a.real() - floor(a.real() / b.real()) * b.real();
+  T i = a.imag() - floor(a.imag() / b.imag()) * b.imag();
+  return complex_t<T>{r, i};
+}
+
+template <typename T>
+inline __host__ __device__ bool operator<(complex_t<T> a, complex_t<T> b) {
+  return (a.real() * a.real() + a.imag() * a.imag()) <
+      (b.real() * b.real() + b.imag() * b.imag());
+}
+
+template <typename T>
+inline __host__ __device__ bool operator>(complex_t<T> a, complex_t<T> b) {
+  return b < a;
+}
+
+template <typename T>
+inline __host__ __device__ bool operator<=(complex_t<T> a, complex_t<T> b) {
+  return !(a > b);
+}
+
+template <typename T>
+inline __host__ __device__ bool operator>=(complex_t<T> a, complex_t<T> b) {
+  return !(a < b);
+}
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/device/cucomplex_math.cuh b/mlx/backend/cuda/device/cucomplex_math.cuh
deleted file mode 100644
index 612650c06..000000000
--- a/mlx/backend/cuda/device/cucomplex_math.cuh
+++ /dev/null
@@ -1,240 +0,0 @@
-// Copyright © 2025 Apple Inc.
-// Copyright © 2017-2024 The Simons Foundation, Inc.
-//
-// FINUFFT is licensed under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance with the
-// License. You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-// Forked from
-// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h
-
-#pragma once
-
-#include <cuComplex.h>
-
-// This header provides some helper functions for cuComplex types.
-// It mainly wraps existing CUDA implementations to provide operator overloads
-// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are
-// all provided by CUDA
-
-__forceinline__ __host__ __device__ cuDoubleComplex
-operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
-  return cuCadd(a, b);
-}
-
-__forceinline__ __host__ __device__ cuDoubleComplex
-operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
-  return cuCsub(a, b);
-}
-
-__forceinline__ __host__ __device__ cuDoubleComplex
-operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
-  return cuCmul(a, b);
-}
-
-__forceinline__ __host__ __device__ cuDoubleComplex
-operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
-  return cuCdiv(a, b);
-}
-
-__forceinline__ __host__ __device__ cuDoubleComplex
-operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
-  double r = cuCreal(a) - (floorf(cuCreal(a) / cuCreal(b)) * cuCreal(b));
-  double i = cuCimag(a) - (floorf(cuCimag(a) / cuCimag(b)) * cuCimag(b));
-  return make_cuDoubleComplex(r, i);
-}
-
-__forceinline__ __host__ __device__ bool operator==(
-    const cuDoubleComplex& a,
-    const cuDoubleComplex& b) {
-  return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
-}
-
-__forceinline__ __host__ __device__ bool operator!=(
-    const cuDoubleComplex& a,
-    const cuDoubleComplex& b) {
-  return !(a == b);
-}
-
-__forceinline__ __host__ __device__ bool operator>(
-    const cuDoubleComplex& a,
-    const cuDoubleComplex& b) {
-  double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
-  double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
-  return mag_a > mag_b;
-}
-
-__forceinline__ __host__ __device__ bool operator>=(
-    const cuDoubleComplex& a,
-    const cuDoubleComplex& b) {
-  return a > b || a == b;
-}
-
-__forceinline__ __host__ __device__ bool operator<(
-    const cuDoubleComplex& a,
-    const cuDoubleComplex& b) {
-  return b > a;
-}
-
-__forceinline__ __host__ __device__ bool operator<=(
-    const cuDoubleComplex& a,
-    const cuDoubleComplex& b) {
-  return b > a || a == b;
-}
-
-__forceinline__ __host__ __device__ cuDoubleComplex
-operator+(const cuDoubleComplex& a, double b) {
-  return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
-}
-
-__forceinline__ __host__ __device__ cuDoubleComplex
-operator+(double a, const cuDoubleComplex& b) {
-  return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
-}
-
-__forceinline__ __host__ __device__ cuDoubleComplex
-operator-(const cuDoubleComplex& a, double b) {
-  return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
-}
-
-__forceinline__ __host__ __device__ cuDoubleComplex
-operator-(double a, const cuDoubleComplex& b) {
-  return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
-}
-
-__forceinline__ __host__ __device__ cuDoubleComplex
-operator*(const cuDoubleComplex& a, double b) {
-  return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
-}
-
-__forceinline__ __host__ __device__ cuDoubleComplex
-operator*(double a, const cuDoubleComplex& b) {
-  return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
-}
-
-__forceinline__ __host__ __device__ cuDoubleComplex
-operator/(const cuDoubleComplex& a, double b) {
-  return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
-}
-
-__forceinline__ __host__ __device__ cuDoubleComplex
-operator/(double a, const cuDoubleComplex& b) {
-  double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
-  return make_cuDoubleComplex(
-      (a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
-}
-
-__forceinline__ __host__ __device__ cuFloatComplex
-operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
-  return cuCaddf(a, b);
-}
-
-__forceinline__ __host__ __device__ cuFloatComplex
-operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
-  return cuCsubf(a, b);
-}
-
-__forceinline__ __host__ __device__ cuFloatComplex
-operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
-  return cuCmulf(a, b);
-}
-
-__forceinline__ __host__ __device__ cuFloatComplex
-operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
-  return cuCdivf(a, b);
-}
-
-__forceinline__ __host__ __device__ cuFloatComplex
-operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
-  float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
-  float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
-  return make_cuFloatComplex(r, i);
-}
-
-__forceinline__ __host__ __device__ bool operator==(
-    const cuFloatComplex& a,
-    const cuFloatComplex& b) {
-  return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
-}
-
-__forceinline__ __host__ __device__ bool operator!=(
-    const cuFloatComplex& a,
-    const cuFloatComplex& b) {
-  return !(a == b);
-}
-
-__forceinline__ __host__ __device__ bool operator>(
-    const cuFloatComplex& a,
-    const cuFloatComplex& b) {
-  float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
-  float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
-  return mag_a > mag_b;
-}
-
-__forceinline__ __host__ __device__ bool operator>=(
-    const cuFloatComplex& a,
-    const cuFloatComplex& b) {
-  return a > b || a == b;
-}
-
-__forceinline__ __host__ __device__ bool operator<(
-    const cuFloatComplex& a,
-    const cuFloatComplex& b) {
-  return b > a;
-}
-
-__forceinline__ __host__ __device__ bool operator<=(
-    const cuFloatComplex& a,
-    const cuFloatComplex& b) {
-  return b > a || a == b;
-}
-
-__forceinline__ __host__ __device__ cuFloatComplex
-operator+(const cuFloatComplex& a, float b) {
-  return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
-}
-
-__forceinline__ __host__ __device__ cuFloatComplex
-operator+(float a, const cuFloatComplex& b) {
-  return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
-}
-
-__forceinline__ __host__ __device__ cuFloatComplex
-operator-(const cuFloatComplex& a, float b) {
-  return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
-}
-
-__forceinline__ __host__ __device__ cuFloatComplex
-operator-(float a, const cuFloatComplex& b) {
-  return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
-}
-
-__forceinline__ __host__ __device__ cuFloatComplex
-operator*(const cuFloatComplex& a, float b) {
-  return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
-}
-
-__forceinline__ __host__ __device__ cuFloatComplex
-operator*(float a, const cuFloatComplex& b) {
-  return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
-}
-
-__forceinline__ __host__ __device__ cuFloatComplex
-operator/(const cuFloatComplex& a, float b) {
-  return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
-}
-
-__forceinline__ __host__ __device__ cuFloatComplex
-operator/(float a, const cuFloatComplex& b) {
-  float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
-  return make_cuFloatComplex(
-      (a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
-}
diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh
index 447569eeb..aebed1e4d 100644
--- a/mlx/backend/cuda/device/unary_ops.cuh
+++ b/mlx/backend/cuda/device/unary_ops.cuh
@@ -2,12 +2,10 @@
 
 #pragma once
 
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
 #include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/device/utils.cuh"
 
 #include <math_constants.h>
-#include <cuda/std/complex>
 
 namespace mlx::core::cu {
 
@@ -16,8 +14,6 @@ struct Abs {
   __device__ T operator()(T x) {
     if constexpr (cuda::std::is_unsigned_v<T>) {
       return x;
-    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
     } else {
       return abs(x);
     }
@@ -29,8 +25,6 @@ struct ArcCos {
   __device__ T operator()(T x) {
     return acos(x);
   }
-
-  __device__ cuComplex operator()(cuComplex x);
 };
 
 struct ArcCosh {
@@ -45,8 +39,6 @@ struct ArcSin {
   __device__ T operator()(T x) {
     return asin(x);
   }
-
-  __device__ cuComplex operator()(cuComplex x);
 };
 
 struct ArcSinh {
@@ -61,8 +53,6 @@ struct ArcTan {
   __device__ T operator()(T x) {
     return atan(x);
   }
-
-  __device__ cuComplex operator()(cuComplex x);
 };
 
 struct ArcTanh {
@@ -84,6 +74,8 @@ struct Ceil {
   __device__ T operator()(T x) {
     if constexpr (cuda::std::is_integral_v<T>) {
       return x;
+    } else if constexpr (is_complex_v<T>) {
+      return T{ceil(x.real()), ceil(x.imag())};
     } else {
       return ceil(x);
     }
@@ -91,34 +83,23 @@ struct Ceil {
 };
 
 struct Conjugate {
-  __device__ cuComplex operator()(cuComplex x) {
-    return {cuCrealf(x), -cuCimagf(x)};
+  template <typename T>
+  __device__ complex_t<T> operator()(complex_t<T> x) {
+    return conj(x);
   }
 };
 
 struct Cos {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return {
-          cos(cuCrealf(x)) * cosh(cuCimagf(x)),
-          -sin(cuCrealf(x)) * sinh(cuCimagf(x))};
-    } else {
-      return cos(x);
-    }
+    return cos(x);
   }
 };
 
 struct Cosh {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return {
-          cosh(cuCrealf(x)) * cos(cuCimagf(x)),
-          sinh(cuCrealf(x)) * sin(cuCimagf(x))};
-    } else {
-      return cosh(x);
-    }
+    return cosh(x);
   }
 };
 
@@ -151,12 +132,7 @@ struct ErfInv {
 struct Exp {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      auto r = exp(cuda::std::complex<float>{cuCrealf(x), cuCimagf(x)});
-      return cuComplex{r.real(), r.imag()};
-    } else {
-      return exp(x);
-    }
+    return exp(x);
   }
 };
 
@@ -178,6 +154,8 @@ struct Floor {
   __device__ T operator()(T x) {
     if constexpr (cuda::std::is_integral_v<T>) {
       return x;
+    } else if constexpr (is_complex_v<T>) {
+      return T{floor(x.real()), floor(x.imag())};
     } else {
      return floor(x);
     }
@@ -185,30 +163,25 @@ struct Floor {
 };
 
 struct Imag {
-  __device__ float operator()(cuComplex x) {
-    return cuCimagf(x);
+  template <typename T>
+  __device__ auto operator()(complex_t<T> x) {
+    return x.imag();
   }
 };
 
 struct Log {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      auto r = log(cuCrealf(Abs{}(x)));
-      auto i = atan2f(cuCimagf(x), cuCrealf(x));
-      return {r, i};
-    } else {
-      return log(x);
-    }
+    return log(x);
   }
 };
 
 struct Log2 {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+    if constexpr (is_complex_v<T>) {
       auto y = Log{}(x);
-      return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
+      return {y.real() / CUDART_LN2_F, y.imag() / CUDART_LN2_F};
     } else {
       return log2(x);
     }
@@ -218,23 +191,17 @@ struct Log2 {
 struct Log10 {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      auto y = Log{}(x);
-      return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
-      return y;
-    } else {
-      return log10(x);
-    }
+    return log10(x);
   }
 };
 
 struct Log1p {
   template <typename T>
   __device__ T operator()(T z) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      float x = cuCrealf(z);
-      float y = cuCimagf(z);
-      float zabs = cuCrealf(Abs{}(z));
+    if constexpr (is_complex_v<T>) {
+      float x = z.real();
+      float y = z.imag();
+      float zabs = Abs{}(z).real();
       float theta = atan2f(y, x + 1);
       if (zabs < 0.5f) {
         float r = x * (2 + x) + y * y;
@@ -261,8 +228,8 @@ struct LogicalNot {
 struct Negative {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return 0 - x;
+    if constexpr (is_complex_v<T>) {
+      return T{0, 0} - x;
     } else {
      return -x;
    }
@@ -270,16 +237,17 @@ struct Negative {
 };
 
 struct Real {
-  __device__ float operator()(cuComplex x) {
-    return cuCrealf(x);
+  template <typename T>
+  __device__ auto operator()(complex_t<T> x) {
+    return x.real();
   }
 };
 
 struct Round {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return {rint(cuCrealf(x)), rint(cuCimagf(x))};
+    if constexpr (is_complex_v<T>) {
+      return {rint(x.real()), rint(x.imag())};
     } else {
       return rint(x);
     }
@@ -299,8 +267,8 @@ struct Sign {
   __device__ T operator()(T x) {
     if constexpr (cuda::std::is_unsigned_v<T>) {
       return x != 0;
-    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
+    } else if constexpr (is_complex_v<T>) {
+      if (x.real() == 0 && x.imag() == 0) {
        return x;
      } else {
        return x / Abs()(x);
@@ -316,26 +284,14 @@ struct Sign {
 struct Sin {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return {
-          sin(cuCrealf(x)) * cosh(cuCimagf(x)),
-          cos(cuCrealf(x)) * sinh(cuCimagf(x))};
-    } else {
-      return sin(x);
-    }
+    return sin(x);
   }
 };
 
 struct Sinh {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      return {
-          sinh(cuCrealf(x)) * cos(cuCimagf(x)),
-          cosh(cuCrealf(x)) * sin(cuCimagf(x))};
-    } else {
-      return sinh(x);
-    }
+    return sinh(x);
   }
 };
 
@@ -351,77 +307,31 @@ struct Sqrt {
   __device__ T operator()(T x) {
     return sqrt(x);
   }
-
-  __device__ cuComplex operator()(cuComplex x) {
-    auto xr = cuCrealf(x);
-    auto xi = cuCimagf(x);
-    if (xr == 0.0f && xi == 0.0f) {
-      return {0.0f, 0.0f};
-    }
-    auto r = cuCrealf(Abs{}(x));
-    auto a = sqrt((r + xr) / 2.0f);
-    auto b_abs = sqrt((r - xr) / 2.0f);
-    auto b = copysign(b_abs, xi);
-    return {a, b};
-  }
 };
 
 struct Rsqrt {
   template <typename T>
   __device__ T operator()(T x) {
-    return rsqrt(x);
-  }
-  __device__ cuComplex operator()(cuComplex x) {
-    return 1.0f / Sqrt{}(x);
+    if constexpr (is_complex_v<T>) {
+      return 1.0f / Sqrt{}(x);
+    } else {
+      return rsqrt(x);
+    }
   }
 };
 
 struct Tan {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      float tan_a = tan(cuCrealf(x));
-      float tanh_b = tanh(cuCimagf(x));
-      float t1 = tan_a * tanh_b;
-      float denom = 1. + t1 * t1;
-      return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
-    } else {
-      return tan(x);
-    }
+    return tan(x);
   }
 };
 
 struct Tanh {
   template <typename T>
   __device__ T operator()(T x) {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      float tanh_a = tanh(cuCrealf(x));
-      float tan_b = tan(cuCimagf(x));
-      float t1 = tanh_a * tan_b;
-      float denom = 1. + t1 * t1;
-      return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
-    } else {
-      return tanh(x);
-    }
+    return tanh(x);
   }
 };
 
-inline __device__ cuComplex ArcCos::operator()(cuComplex x) {
-  auto i = cuComplex{0.0, 1.0};
-  auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
-  return {cuCimagf(y), -cuCrealf(y)};
-};
-
-inline __device__ cuComplex ArcSin::operator()(cuComplex x) {
-  auto i = cuComplex{0.0f, 1.0f};
-  auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
-  return {cuCimagf(y), -cuCrealf(y)};
-};
-
-inline __device__ cuComplex ArcTan::operator()(cuComplex x) {
-  auto i = cuComplex{0.0f, 1.0f};
-  auto ix = i * x;
-  return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
-};
-
 } // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh
index af022c141..73bc7ff63 100644
--- a/mlx/backend/cuda/device/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -8,9 +8,9 @@
 
 #pragma once
 
+#include "mlx/backend/cuda/device/complex.cuh"
 #include "mlx/backend/cuda/device/config.h"
 
-#include <cuComplex.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <cuda/std/array>
@@ -127,13 +127,13 @@ struct Limits {
   }
 };
 
-template <>
-struct Limits<cuComplex> {
-  static constexpr __host__ __device__ cuComplex max() {
-    return {Limits<float>::max(), Limits<float>::max()};
+template <typename T>
+struct Limits<complex_t<T>> {
+  static constexpr __host__ __device__ complex_t<T> max() {
+    return {Limits<T>::max(), Limits<T>::max()};
   }
-  static constexpr __host__ __device__ cuComplex min() {
-    return {Limits<float>::min(), Limits<float>::min()};
+  static constexpr __host__ __device__ complex_t<T> min() {
+    return {Limits<T>::min(), Limits<T>::min()};
   }
 };
 
diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp
index 4ce79999e..343db902e 100644
--- a/mlx/backend/cuda/jit_module.cpp
+++ b/mlx/backend/cuda/jit_module.cpp
@@ -173,7 +173,7 @@ constexpr const char* g_include_names[] = {
     INCLUDE_PREFIX "binary_ops.cuh",
     INCLUDE_PREFIX "cast_op.cuh",
     INCLUDE_PREFIX "config.h",
-    INCLUDE_PREFIX "cucomplex_math.cuh",
+    INCLUDE_PREFIX "complex.cuh",
     INCLUDE_PREFIX "fp16_math.cuh",
     INCLUDE_PREFIX "indexing.cuh",
     INCLUDE_PREFIX "scatter_ops.cuh",
@@ -189,7 +189,7 @@ constexpr const char* g_headers[] = {
     jit_source_binary_ops,
     jit_source_cast_op,
     jit_source_config,
-    jit_source_cucomplex_math,
+    jit_source_complex,
     jit_source_fp16_math,
     jit_source_indexing,
     jit_source_scatter_ops,
diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh
index eeaf916c1..24c81f2fb 100644
--- a/mlx/backend/cuda/kernel_utils.cuh
+++ b/mlx/backend/cuda/kernel_utils.cuh
@@ -11,7 +11,6 @@
 #include "mlx/array.h"
 #include "mlx/backend/cuda/device/utils.cuh"
 
-#include <cuComplex.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <fmt/format.h>
@@ -79,7 +78,7 @@
 
 template <>
 struct CTypeToCudaType<complex64_t> {
-  using type = cuComplex;
+  using type = cu::complex64_t;
 };
 
 template <typename T>
@@ -91,10 +90,14 @@
 inline constexpr bool is_floating_v = cuda::std::is_same_v<T, float> ||
     cuda::std::is_same_v<T, double> || cuda::std::is_same_v<T, __half> ||
    cuda::std::is_same_v<T, __nv_bfloat16>;
 
+// Type traits for detecting complex numbers.
+template <typename T>
+inline constexpr bool is_complex_v = cuda::std::is_same_v<T, complex64_t> ||
+    cuda::std::is_same_v<T, complex128_t>;
+
 // Type traits for detecting complex or real floating point numbers.
 template <typename T>
-inline constexpr bool is_inexact_v =
-    is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
+inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
 
 // Utility to copy data from vector to array in host.
 template <typename T, size_t NDIM>
diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh
index d0eb3f5c5..02e495594 100644
--- a/mlx/backend/cuda/reduce/reduce.cuh
+++ b/mlx/backend/cuda/reduce/reduce.cuh
@@ -3,7 +3,6 @@
 #include <numeric>
 
 #include "mlx/backend/common/reduce.h"
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/backend/cuda/reduce/reduce_ops.cuh"
 #include "mlx/dtype_utils.h"
diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh
index bc4dce33e..31ba90433 100644
--- a/mlx/backend/cuda/reduce/reduce_ops.cuh
+++ b/mlx/backend/cuda/reduce/reduce_ops.cuh
@@ -151,7 +151,7 @@ struct ReduceInit {
 template <typename T>
 struct ReduceInit<Sum, T> {
   static constexpr __host__ __device__ auto value() {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+    if constexpr (is_complex_v<T>) {
       return T{0, 0};
     } else {
       return cast_to<typename ReduceResult<Sum, T>::type>(0);
@@ -162,7 +162,7 @@ struct ReduceInit {
 template <typename T>
 struct ReduceInit<Prod, T> {
   static constexpr __host__ __device__ auto value() {
-    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+    if constexpr (is_complex_v<T>) {
       return T{1, 0};
     } else {
       return cast_to<typename ReduceResult<Prod, T>::type>(1);
diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
index 0d2754ef0..ddb32d05e 100644
--- a/mlx/backend/cuda/unary.cu
+++ b/mlx/backend/cuda/unary.cu
@@ -2,7 +2,6 @@
 
 #include "mlx/backend/common/unary.h"
 #include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
 #include "mlx/backend/cuda/device/unary_ops.cuh"
 #include "mlx/backend/cuda/iterators/general_iterator.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
@@ -71,10 +70,10 @@ constexpr bool supports_unary_op() {
         !std::is_same_v<In, bool>;
   }
   if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
-    return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
+    return std::is_same_v<In, Out> && !mlx::core::is_complex_v<In>;
   }
   if (std::is_same_v<Op, Conjugate>) {
-    return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
+    return std::is_same_v<In, Out> && mlx::core::is_complex_v<In>;
   }
   if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
       std::is_same_v<Op, Exp> || std::is_same_v<Op, Log> ||
@@ -88,7 +87,7 @@ constexpr bool supports_unary_op() {
     return std::is_same_v<In, Out> && is_inexact_v<In>;
   }
   if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
-    return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
+    return mlx::core::is_complex_v<In> && std::is_same_v<Out, float>;
   }
   if (std::is_same_v<Op, LogicalNot>) {
     return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp
index cc05428a8..1c12fa4df 100644
--- a/mlx/backend/cuda/utils.cpp
+++ b/mlx/backend/cuda/utils.cpp
@@ -61,7 +61,7 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
     case float64:
       return "double";
     case complex64:
-      return "cuComplex";
+      return "complex64_t";
     default:
      return "unknown";
   }
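-- 
Reviewer note, not part of the patch: a minimal sketch of how the helpers
introduced in complex.cuh behave. The kernel and file below are hypothetical
(demo_kernel, main); the only assumptions are what the new header defines
(complex64_t, the element-wise operator%, and the magnitude-ordered
comparisons) plus an nvcc invocation that can see the mlx include directory.

  #include <cuda_runtime.h>
  #include <cstdio>
  #include "mlx/backend/cuda/device/complex.cuh"

  // The operators live in mlx::core::cu, not cuda::std, so ADL on
  // cuda::std::complex will not find them; bring them into scope explicitly.
  using namespace mlx::core::cu;

  __global__ void demo_kernel(bool* lt, complex64_t* rem) {
    complex64_t a{3.0f, 4.0f}; // |a|^2 = 25
    complex64_t b{6.0f, 8.0f}; // |b|^2 = 100
    *lt = a < b;  // operator< orders by squared magnitude: true (25 < 100)
    *rem = a % b; // operator% acts per component: {3 % 6, 4 % 8} = {3, 4}
  }

  int main() {
    bool* lt;
    complex64_t* rem;
    cudaMallocManaged(&lt, sizeof(bool));
    cudaMallocManaged(&rem, sizeof(complex64_t));
    demo_kernel<<<1, 1>>>(lt, rem);
    cudaDeviceSynchronize();
    std::printf("a < b: %d, a %% b: (%g, %g)\n", *lt, rem->real(), rem->imag());
    return 0;
  }

Since the comparison operators order by magnitude rather than lexicographically,
Maximum/Minimum and the sort-based reductions keep the semantics the deleted
cucomplex_math.cuh helpers had, while Power and Exp now defer to the
cuda::std::complex implementations instead of hand-rolled float math.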