mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
[CUDA] Use cuda::std::complex in place of cuComplex (#2372)
This commit is contained in:
parent
f0a0b077a0
commit
cb349a291c
@ -3,7 +3,6 @@
|
|||||||
#include "mlx/backend/common/binary.h"
|
#include "mlx/backend/common/binary.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
#include "mlx/backend/common/binary.h"
|
#include "mlx/backend/common/binary.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
#include "mlx/backend/cuda/device/complex.cuh"
|
||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
|
|
||||||
#include <cuda/atomic>
|
#include <cuda/atomic>
|
||||||
@ -48,7 +48,7 @@ inline __device__ void atomic_add(__half* out, __half val) {
|
|||||||
atomicAdd(out, val);
|
atomicAdd(out, val);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
|
inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
|
||||||
#if __CUDA_ARCH__ < 900
|
#if __CUDA_ARCH__ < 900
|
||||||
atomic_add_general(out, val);
|
atomic_add_general(out, val);
|
||||||
#else
|
#else
|
||||||
|
@ -44,7 +44,7 @@ struct Remainder {
|
|||||||
} else {
|
} else {
|
||||||
return x % y;
|
return x % y;
|
||||||
}
|
}
|
||||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
} else if constexpr (is_complex_v<T>) {
|
||||||
return x % y;
|
return x % y;
|
||||||
} else {
|
} else {
|
||||||
T r = fmod(x, y);
|
T r = fmod(x, y);
|
||||||
@ -66,14 +66,12 @@ struct Equal {
|
|||||||
struct NaNEqual {
|
struct NaNEqual {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ bool operator()(T x, T y) {
|
__device__ bool operator()(T x, T y) {
|
||||||
if constexpr (std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
return x == y ||
|
return x == y ||
|
||||||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
|
(isnan(x.real()) && isnan(y.real()) && isnan(x.imag()) &&
|
||||||
isnan(cuCimagf(y))) ||
|
isnan(y.imag())) ||
|
||||||
(cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
|
(x.real() == y.real() && isnan(x.imag()) && isnan(y.imag())) ||
|
||||||
isnan(cuCimagf(y))) ||
|
(isnan(x.real()) && isnan(y.real()) && x.imag() == y.imag());
|
||||||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
|
|
||||||
cuCimagf(x) == cuCimagf(y));
|
|
||||||
} else {
|
} else {
|
||||||
return x == y || (isnan(x) && isnan(y));
|
return x == y || (isnan(x) && isnan(y));
|
||||||
}
|
}
|
||||||
@ -111,17 +109,17 @@ struct LessEqual {
|
|||||||
struct LogAddExp {
|
struct LogAddExp {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x, T y) {
|
__device__ T operator()(T x, T y) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
if (isnan(x.real()) || isnan(x.imag()) || isnan(y.real()) ||
|
||||||
isnan(cuCimagf(y))) {
|
isnan(y.imag())) {
|
||||||
return {
|
return {
|
||||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||||
}
|
}
|
||||||
auto max = cuCrealf(x) > cuCrealf(y) ? x : y;
|
auto max = x.real() > y.real() ? x : y;
|
||||||
auto min = cuCrealf(x) < cuCrealf(y) ? x : y;
|
auto min = x.real() < y.real() ? x : y;
|
||||||
auto min_real = cuCrealf(min);
|
auto min_real = min.real();
|
||||||
auto max_real = cuCrealf(max);
|
auto max_real = max.real();
|
||||||
if (!isfinite(min_real) && (min_real == max_real)) {
|
if (!isfinite(min_real) && (min_real == max_real)) {
|
||||||
if (min_real < 0) {
|
if (min_real < 0) {
|
||||||
return min;
|
return min;
|
||||||
@ -150,8 +148,8 @@ struct Maximum {
|
|||||||
__device__ T operator()(T x, T y) {
|
__device__ T operator()(T x, T y) {
|
||||||
if constexpr (cuda::std::is_integral_v<T>) {
|
if constexpr (cuda::std::is_integral_v<T>) {
|
||||||
return max(x, y);
|
return max(x, y);
|
||||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
} else if constexpr (is_complex_v<T>) {
|
||||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
|
if (isnan(x.real()) || isnan(x.imag())) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
return x > y ? x : y;
|
return x > y ? x : y;
|
||||||
@ -169,8 +167,8 @@ struct Minimum {
|
|||||||
__device__ T operator()(T x, T y) {
|
__device__ T operator()(T x, T y) {
|
||||||
if constexpr (cuda::std::is_integral_v<T>) {
|
if constexpr (cuda::std::is_integral_v<T>) {
|
||||||
return min(x, y);
|
return min(x, y);
|
||||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
} else if constexpr (is_complex_v<T>) {
|
||||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
|
if (isnan(x.real()) || isnan(x.imag())) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
return x < y ? x : y;
|
return x < y ? x : y;
|
||||||
@ -193,8 +191,8 @@ struct Multiply {
|
|||||||
struct NotEqual {
|
struct NotEqual {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ bool operator()(T x, T y) {
|
__device__ bool operator()(T x, T y) {
|
||||||
if constexpr (std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
|
return x.real() != y.real() || x.imag() != y.imag();
|
||||||
} else {
|
} else {
|
||||||
return x != y;
|
return x != y;
|
||||||
}
|
}
|
||||||
@ -214,19 +212,8 @@ struct Power {
|
|||||||
base *= base;
|
base *= base;
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
} else if constexpr (is_complex_v<T>) {
|
||||||
if (base.y == 0 && base.x == 0) {
|
return pow(base, exp);
|
||||||
if (isnan(exp.x) || isnan(exp.y)) {
|
|
||||||
auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
|
|
||||||
return make_cuFloatComplex(nan, nan);
|
|
||||||
}
|
|
||||||
return make_cuFloatComplex(0.0, 0.0);
|
|
||||||
}
|
|
||||||
auto x_theta = atan2f(base.y, base.x);
|
|
||||||
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
|
|
||||||
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
|
|
||||||
auto phase = exp.y * x_ln_r + exp.x * x_theta;
|
|
||||||
return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
|
|
||||||
} else {
|
} else {
|
||||||
return powf(base, exp);
|
return powf(base, exp);
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,8 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cuComplex.h>
|
#include "mlx/backend/cuda/device/complex.cuh"
|
||||||
|
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <thrust/iterator/transform_iterator.h>
|
#include <thrust/iterator/transform_iterator.h>
|
||||||
@ -20,50 +21,43 @@ struct CastOp {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Castings between complex and boolean.
|
// Castings between complex and boolean.
|
||||||
// TODO: Should make a custom complex type.
|
template <typename T>
|
||||||
template <>
|
struct CastOp<complex_t<T>, bool> {
|
||||||
struct CastOp<cuComplex, bool> {
|
|
||||||
static constexpr bool is_castable = true;
|
static constexpr bool is_castable = true;
|
||||||
|
|
||||||
__device__ bool operator()(cuComplex x) {
|
__device__ bool operator()(complex_t<T> x) {
|
||||||
return x.x != 0 && x.y != 0;
|
return x.real() != 0 && x.imag() != 0;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <typename T>
|
||||||
struct CastOp<bool, cuComplex> {
|
struct CastOp<bool, complex_t<T>> {
|
||||||
static constexpr bool is_castable = true;
|
static constexpr bool is_castable = true;
|
||||||
|
|
||||||
__device__ cuComplex operator()(bool x) {
|
__device__ complex_t<T> operator()(bool x) {
|
||||||
return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
|
return x ? complex_t<T>{1, 1} : complex_t<T>{0, 0};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Converting a complex number to real number discards the imaginary part.
|
// Converting a complex number to real number discards the imaginary part.
|
||||||
template <typename DstT>
|
template <typename T, typename DstT>
|
||||||
struct CastOp<
|
struct CastOp<complex_t<T>, DstT, cuda::std::enable_if_t<!is_complex_v<DstT>>> {
|
||||||
cuComplex,
|
static constexpr bool is_castable = cuda::std::is_convertible_v<T, DstT>;
|
||||||
DstT,
|
|
||||||
cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
|
|
||||||
static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;
|
|
||||||
|
|
||||||
__device__ DstT operator()(cuComplex x) {
|
__device__ DstT operator()(complex_t<T> x) {
|
||||||
static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
|
static_assert(!is_complex_v<DstT>);
|
||||||
return static_cast<DstT>(cuCrealf(x));
|
return static_cast<DstT>(x.real());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Allow converting a real number to complex number.
|
// Allow converting a real number to complex number.
|
||||||
template <typename SrcT>
|
template <typename SrcT, typename T>
|
||||||
struct CastOp<
|
struct CastOp<SrcT, complex_t<T>, cuda::std::enable_if_t<!is_complex_v<SrcT>>> {
|
||||||
SrcT,
|
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, T>;
|
||||||
cuComplex,
|
|
||||||
cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
|
|
||||||
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;
|
|
||||||
|
|
||||||
__device__ cuComplex operator()(SrcT x) {
|
__device__ complex_t<T> operator()(SrcT x) {
|
||||||
static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
|
static_assert(!is_complex_v<SrcT>);
|
||||||
return cuComplex{static_cast<float>(x), 0};
|
return complex_t<T>{static_cast<T>(x), 0};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -88,8 +82,7 @@ struct CastOp<
|
|||||||
SrcT,
|
SrcT,
|
||||||
DstT,
|
DstT,
|
||||||
cuda::std::enable_if_t<
|
cuda::std::enable_if_t<
|
||||||
!cuda::std::is_convertible_v<SrcT, DstT> &&
|
!cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
|
||||||
!cuda::std::is_same_v<SrcT, cuComplex> &&
|
|
||||||
(cuda::std::is_same_v<DstT, __half> ||
|
(cuda::std::is_same_v<DstT, __half> ||
|
||||||
cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
|
cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
|
||||||
static constexpr bool is_castable = true;
|
static constexpr bool is_castable = true;
|
||||||
@ -104,8 +97,7 @@ struct CastOp<
|
|||||||
SrcT,
|
SrcT,
|
||||||
DstT,
|
DstT,
|
||||||
cuda::std::enable_if_t<
|
cuda::std::enable_if_t<
|
||||||
!cuda::std::is_convertible_v<SrcT, DstT> &&
|
!cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
|
||||||
!cuda::std::is_same_v<DstT, cuComplex> &&
|
|
||||||
!cuda::std::is_same_v<DstT, __half> &&
|
!cuda::std::is_same_v<DstT, __half> &&
|
||||||
!cuda::std::is_same_v<DstT, __nv_bfloat16> &&
|
!cuda::std::is_same_v<DstT, __nv_bfloat16> &&
|
||||||
(cuda::std::is_same_v<SrcT, __half> ||
|
(cuda::std::is_same_v<SrcT, __half> ||
|
||||||
|
61
mlx/backend/cuda/device/complex.cuh
Normal file
61
mlx/backend/cuda/device/complex.cuh
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// Make multiplication and division faster.
|
||||||
|
#define LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS
|
||||||
|
|
||||||
|
#include <cuda/std/complex>
|
||||||
|
#include <cuda/std/type_traits>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
// TODO: Consider using a faster implementation as cuda::std::complex has to
|
||||||
|
// conform to C++ standard.
|
||||||
|
template <typename T>
|
||||||
|
using complex_t = cuda::std::complex<T>;
|
||||||
|
|
||||||
|
using complex64_t = complex_t<float>;
|
||||||
|
using complex128_t = complex_t<double>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct is_complex : cuda::std::false_type {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct is_complex<cuda::std::complex<T>> : cuda::std::true_type {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline constexpr bool is_complex_v = is_complex<T>::value;
|
||||||
|
|
||||||
|
// cuda::std::complex is missing some operators.
|
||||||
|
template <typename T>
|
||||||
|
inline __host__ __device__ complex_t<T> operator%(
|
||||||
|
complex_t<T> a,
|
||||||
|
complex_t<T> b) {
|
||||||
|
T r = a.real() - floor(a.real() / b.real()) * b.real();
|
||||||
|
T i = a.imag() - floor(a.imag() / b.imag()) * b.imag();
|
||||||
|
return complex_t<T>{r, i};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __host__ __device__ bool operator<(complex_t<T> a, complex_t<T> b) {
|
||||||
|
return (a.real() * a.real() + a.imag() * a.imag()) <
|
||||||
|
(b.real() * b.real() + b.imag() * b.imag());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __host__ __device__ bool operator>(complex_t<T> a, complex_t<T> b) {
|
||||||
|
return b < a;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __host__ __device__ bool operator<=(complex_t<T> a, complex_t<T> b) {
|
||||||
|
return !(a > b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __host__ __device__ bool operator>=(complex_t<T> a, complex_t<T> b) {
|
||||||
|
return !(a < b);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
@ -1,240 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
// Copyright © 2017-2024 The Simons Foundation, Inc.
|
|
||||||
//
|
|
||||||
// FINUFFT is licensed under the Apache License, Version 2.0 (the
|
|
||||||
// "License"); you may not use this file except in compliance with the
|
|
||||||
// License. You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
//
|
|
||||||
// Forked from
|
|
||||||
// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <cuComplex.h>
|
|
||||||
|
|
||||||
// This header provides some helper functions for cuComplex types.
|
|
||||||
// It mainly wraps existing CUDA implementations to provide operator overloads
|
|
||||||
// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are
|
|
||||||
// all provided by CUDA
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
|
||||||
return cuCadd(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
|
||||||
return cuCsub(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
|
||||||
return cuCmul(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
|
||||||
return cuCdiv(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
|
|
||||||
double r = cuCreal(a) - (floorf(cuCreal(a) / cuCreal(b)) * cuCreal(b));
|
|
||||||
double i = cuCimag(a) - (floorf(cuCimag(a) / cuCimag(b)) * cuCimag(b));
|
|
||||||
return make_cuDoubleComplex(r, i);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator==(
|
|
||||||
const cuDoubleComplex& a,
|
|
||||||
const cuDoubleComplex& b) {
|
|
||||||
return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator!=(
|
|
||||||
const cuDoubleComplex& a,
|
|
||||||
const cuDoubleComplex& b) {
|
|
||||||
return !(a == b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator>(
|
|
||||||
const cuDoubleComplex& a,
|
|
||||||
const cuDoubleComplex& b) {
|
|
||||||
double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
|
|
||||||
double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
|
|
||||||
return mag_a > mag_b;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator>=(
|
|
||||||
const cuDoubleComplex& a,
|
|
||||||
const cuDoubleComplex& b) {
|
|
||||||
return a > b || a == b;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator<(
|
|
||||||
const cuDoubleComplex& a,
|
|
||||||
const cuDoubleComplex& b) {
|
|
||||||
return b > a;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator<=(
|
|
||||||
const cuDoubleComplex& a,
|
|
||||||
const cuDoubleComplex& b) {
|
|
||||||
return b > a || a == b;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator+(const cuDoubleComplex& a, double b) {
|
|
||||||
return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator+(double a, const cuDoubleComplex& b) {
|
|
||||||
return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator-(const cuDoubleComplex& a, double b) {
|
|
||||||
return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator-(double a, const cuDoubleComplex& b) {
|
|
||||||
return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator*(const cuDoubleComplex& a, double b) {
|
|
||||||
return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator*(double a, const cuDoubleComplex& b) {
|
|
||||||
return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator/(const cuDoubleComplex& a, double b) {
|
|
||||||
return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuDoubleComplex
|
|
||||||
operator/(double a, const cuDoubleComplex& b) {
|
|
||||||
double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
|
|
||||||
return make_cuDoubleComplex(
|
|
||||||
(a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
|
|
||||||
return cuCaddf(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
|
|
||||||
return cuCsubf(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
|
|
||||||
return cuCmulf(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
|
|
||||||
return cuCdivf(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
|
|
||||||
float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
|
|
||||||
float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
|
|
||||||
return make_cuFloatComplex(r, i);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator==(
|
|
||||||
const cuFloatComplex& a,
|
|
||||||
const cuFloatComplex& b) {
|
|
||||||
return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator!=(
|
|
||||||
const cuFloatComplex& a,
|
|
||||||
const cuFloatComplex& b) {
|
|
||||||
return !(a == b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator>(
|
|
||||||
const cuFloatComplex& a,
|
|
||||||
const cuFloatComplex& b) {
|
|
||||||
float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
|
|
||||||
float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
|
|
||||||
return mag_a > mag_b;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator>=(
|
|
||||||
const cuFloatComplex& a,
|
|
||||||
const cuFloatComplex& b) {
|
|
||||||
return a > b || a == b;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator<(
|
|
||||||
const cuFloatComplex& a,
|
|
||||||
const cuFloatComplex& b) {
|
|
||||||
return b > a;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ bool operator<=(
|
|
||||||
const cuFloatComplex& a,
|
|
||||||
const cuFloatComplex& b) {
|
|
||||||
return b > a || a == b;
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator+(const cuFloatComplex& a, float b) {
|
|
||||||
return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator+(float a, const cuFloatComplex& b) {
|
|
||||||
return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator-(const cuFloatComplex& a, float b) {
|
|
||||||
return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator-(float a, const cuFloatComplex& b) {
|
|
||||||
return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator*(const cuFloatComplex& a, float b) {
|
|
||||||
return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator*(float a, const cuFloatComplex& b) {
|
|
||||||
return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator/(const cuFloatComplex& a, float b) {
|
|
||||||
return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
|
|
||||||
}
|
|
||||||
|
|
||||||
__forceinline__ __host__ __device__ cuFloatComplex
|
|
||||||
operator/(float a, const cuFloatComplex& b) {
|
|
||||||
float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
|
|
||||||
return make_cuFloatComplex(
|
|
||||||
(a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
|
|
||||||
}
|
|
@ -2,12 +2,10 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
|
||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
#include <math_constants.h>
|
#include <math_constants.h>
|
||||||
#include <cuda/std/complex>
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
@ -16,8 +14,6 @@ struct Abs {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_unsigned_v<T>) {
|
if constexpr (cuda::std::is_unsigned_v<T>) {
|
||||||
return x;
|
return x;
|
||||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
|
||||||
return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
|
|
||||||
} else {
|
} else {
|
||||||
return abs(x);
|
return abs(x);
|
||||||
}
|
}
|
||||||
@ -29,8 +25,6 @@ struct ArcCos {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return acos(x);
|
return acos(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ cuComplex operator()(cuComplex x);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ArcCosh {
|
struct ArcCosh {
|
||||||
@ -45,8 +39,6 @@ struct ArcSin {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return asin(x);
|
return asin(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ cuComplex operator()(cuComplex x);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ArcSinh {
|
struct ArcSinh {
|
||||||
@ -61,8 +53,6 @@ struct ArcTan {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return atan(x);
|
return atan(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ cuComplex operator()(cuComplex x);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ArcTanh {
|
struct ArcTanh {
|
||||||
@ -84,6 +74,8 @@ struct Ceil {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_integral_v<T>) {
|
if constexpr (cuda::std::is_integral_v<T>) {
|
||||||
return x;
|
return x;
|
||||||
|
} else if constexpr (is_complex_v<T>) {
|
||||||
|
return T{ceil(x.real()), ceil(x.imag())};
|
||||||
} else {
|
} else {
|
||||||
return ceil(x);
|
return ceil(x);
|
||||||
}
|
}
|
||||||
@ -91,34 +83,23 @@ struct Ceil {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct Conjugate {
|
struct Conjugate {
|
||||||
__device__ cuComplex operator()(cuComplex x) {
|
template <typename T>
|
||||||
return {cuCrealf(x), -cuCimagf(x)};
|
__device__ complex_t<T> operator()(complex_t<T> x) {
|
||||||
|
return conj(x);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Cos {
|
struct Cos {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return cos(x);
|
||||||
return {
|
|
||||||
cos(cuCrealf(x)) * cosh(cuCimagf(x)),
|
|
||||||
-sin(cuCrealf(x)) * sinh(cuCimagf(x))};
|
|
||||||
} else {
|
|
||||||
return cos(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Cosh {
|
struct Cosh {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return cosh(x);
|
||||||
return {
|
|
||||||
cosh(cuCrealf(x)) * cos(cuCimagf(x)),
|
|
||||||
sinh(cuCrealf(x)) * sin(cuCimagf(x))};
|
|
||||||
} else {
|
|
||||||
return cosh(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -151,12 +132,7 @@ struct ErfInv {
|
|||||||
struct Exp {
|
struct Exp {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return exp(x);
|
||||||
auto r = exp(cuda::std::complex<float>{cuCrealf(x), cuCimagf(x)});
|
|
||||||
return cuComplex{r.real(), r.imag()};
|
|
||||||
} else {
|
|
||||||
return exp(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -178,6 +154,8 @@ struct Floor {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_integral_v<T>) {
|
if constexpr (cuda::std::is_integral_v<T>) {
|
||||||
return x;
|
return x;
|
||||||
|
} else if constexpr (is_complex_v<T>) {
|
||||||
|
return T{floor(x.real()), floor(x.imag())};
|
||||||
} else {
|
} else {
|
||||||
return floor(x);
|
return floor(x);
|
||||||
}
|
}
|
||||||
@ -185,30 +163,25 @@ struct Floor {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct Imag {
|
struct Imag {
|
||||||
__device__ float operator()(cuComplex x) {
|
template <typename T>
|
||||||
return cuCimagf(x);
|
__device__ auto operator()(complex_t<T> x) {
|
||||||
|
return x.imag();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log {
|
struct Log {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return log(x);
|
||||||
auto r = log(cuCrealf(Abs{}(x)));
|
|
||||||
auto i = atan2f(cuCimagf(x), cuCrealf(x));
|
|
||||||
return {r, i};
|
|
||||||
} else {
|
|
||||||
return log(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log2 {
|
struct Log2 {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
auto y = Log{}(x);
|
auto y = Log{}(x);
|
||||||
return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
|
return {y.real() / CUDART_LN2_F, y.imag() / CUDART_LN2_F};
|
||||||
} else {
|
} else {
|
||||||
return log2(x);
|
return log2(x);
|
||||||
}
|
}
|
||||||
@ -218,23 +191,17 @@ struct Log2 {
|
|||||||
struct Log10 {
|
struct Log10 {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return log10(x);
|
||||||
auto y = Log{}(x);
|
|
||||||
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
|
|
||||||
return y;
|
|
||||||
} else {
|
|
||||||
return log10(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log1p {
|
struct Log1p {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T z) {
|
__device__ T operator()(T z) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
float x = cuCrealf(z);
|
float x = z.real();
|
||||||
float y = cuCimagf(z);
|
float y = z.imag();
|
||||||
float zabs = cuCrealf(Abs{}(z));
|
float zabs = Abs{}(z).real();
|
||||||
float theta = atan2f(y, x + 1);
|
float theta = atan2f(y, x + 1);
|
||||||
if (zabs < 0.5f) {
|
if (zabs < 0.5f) {
|
||||||
float r = x * (2 + x) + y * y;
|
float r = x * (2 + x) + y * y;
|
||||||
@ -261,8 +228,8 @@ struct LogicalNot {
|
|||||||
struct Negative {
|
struct Negative {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
return 0 - x;
|
return T{0, 0} - x;
|
||||||
} else {
|
} else {
|
||||||
return -x;
|
return -x;
|
||||||
}
|
}
|
||||||
@ -270,16 +237,17 @@ struct Negative {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct Real {
|
struct Real {
|
||||||
__device__ float operator()(cuComplex x) {
|
template <typename T>
|
||||||
return cuCrealf(x);
|
__device__ auto operator()(complex_t<T> x) {
|
||||||
|
return x.real();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Round {
|
struct Round {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
return {rint(cuCrealf(x)), rint(cuCimagf(x))};
|
return {rint(x.real()), rint(x.imag())};
|
||||||
} else {
|
} else {
|
||||||
return rint(x);
|
return rint(x);
|
||||||
}
|
}
|
||||||
@ -299,8 +267,8 @@ struct Sign {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_unsigned_v<T>) {
|
if constexpr (cuda::std::is_unsigned_v<T>) {
|
||||||
return x != 0;
|
return x != 0;
|
||||||
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
} else if constexpr (is_complex_v<T>) {
|
||||||
if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
|
if (x.real() == 0 && x.imag() == 0) {
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
return x / Abs()(x);
|
return x / Abs()(x);
|
||||||
@ -316,26 +284,14 @@ struct Sign {
|
|||||||
struct Sin {
|
struct Sin {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return sin(x);
|
||||||
return {
|
|
||||||
sin(cuCrealf(x)) * cosh(cuCimagf(x)),
|
|
||||||
cos(cuCrealf(x)) * sinh(cuCimagf(x))};
|
|
||||||
} else {
|
|
||||||
return sin(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Sinh {
|
struct Sinh {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return sinh(x);
|
||||||
return {
|
|
||||||
sinh(cuCrealf(x)) * cos(cuCimagf(x)),
|
|
||||||
cosh(cuCrealf(x)) * sin(cuCimagf(x))};
|
|
||||||
} else {
|
|
||||||
return sinh(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -351,77 +307,31 @@ struct Sqrt {
|
|||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return sqrt(x);
|
return sqrt(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ cuComplex operator()(cuComplex x) {
|
|
||||||
auto xr = cuCrealf(x);
|
|
||||||
auto xi = cuCimagf(x);
|
|
||||||
if (xr == 0.0f && xi == 0.0f) {
|
|
||||||
return {0.0f, 0.0f};
|
|
||||||
}
|
|
||||||
auto r = cuCrealf(Abs{}(x));
|
|
||||||
auto a = sqrt((r + xr) / 2.0f);
|
|
||||||
auto b_abs = sqrt((r - xr) / 2.0f);
|
|
||||||
auto b = copysign(b_abs, xi);
|
|
||||||
return {a, b};
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Rsqrt {
|
struct Rsqrt {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return rsqrt(x);
|
if constexpr (is_complex_v<T>) {
|
||||||
}
|
return 1.0f / Sqrt{}(x);
|
||||||
__device__ cuComplex operator()(cuComplex x) {
|
} else {
|
||||||
return 1.0f / Sqrt{}(x);
|
return rsqrt(x);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Tan {
|
struct Tan {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return tan(x);
|
||||||
float tan_a = tan(cuCrealf(x));
|
|
||||||
float tanh_b = tanh(cuCimagf(x));
|
|
||||||
float t1 = tan_a * tanh_b;
|
|
||||||
float denom = 1. + t1 * t1;
|
|
||||||
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
|
|
||||||
} else {
|
|
||||||
return tan(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Tanh {
|
struct Tanh {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
return tanh(x);
|
||||||
float tanh_a = tanh(cuCrealf(x));
|
|
||||||
float tan_b = tan(cuCimagf(x));
|
|
||||||
float t1 = tanh_a * tan_b;
|
|
||||||
float denom = 1. + t1 * t1;
|
|
||||||
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
|
||||||
} else {
|
|
||||||
return tanh(x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
inline __device__ cuComplex ArcCos::operator()(cuComplex x) {
|
|
||||||
auto i = cuComplex{0.0, 1.0};
|
|
||||||
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
|
|
||||||
return {cuCimagf(y), -cuCrealf(y)};
|
|
||||||
};
|
|
||||||
|
|
||||||
inline __device__ cuComplex ArcSin::operator()(cuComplex x) {
|
|
||||||
auto i = cuComplex{0.0f, 1.0f};
|
|
||||||
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
|
|
||||||
return {cuCimagf(y), -cuCrealf(y)};
|
|
||||||
};
|
|
||||||
|
|
||||||
inline __device__ cuComplex ArcTan::operator()(cuComplex x) {
|
|
||||||
auto i = cuComplex{0.0f, 1.0f};
|
|
||||||
auto ix = i * x;
|
|
||||||
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
@ -8,9 +8,9 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device/complex.cuh"
|
||||||
#include "mlx/backend/cuda/device/config.h"
|
#include "mlx/backend/cuda/device/config.h"
|
||||||
|
|
||||||
#include <cuComplex.h>
|
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <cuda/std/array>
|
#include <cuda/std/array>
|
||||||
@ -127,13 +127,13 @@ struct Limits<bool> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <typename T>
|
||||||
struct Limits<cuComplex> {
|
struct Limits<complex_t<T>> {
|
||||||
static constexpr __host__ __device__ cuComplex max() {
|
static constexpr __host__ __device__ complex_t<T> max() {
|
||||||
return {Limits<float>::max(), Limits<float>::max()};
|
return {Limits<T>::max(), Limits<T>::max()};
|
||||||
}
|
}
|
||||||
static constexpr __host__ __device__ cuComplex min() {
|
static constexpr __host__ __device__ complex_t<T> min() {
|
||||||
return {Limits<float>::min(), Limits<float>::min()};
|
return {Limits<T>::min(), Limits<T>::min()};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -173,7 +173,7 @@ constexpr const char* g_include_names[] = {
|
|||||||
INCLUDE_PREFIX "binary_ops.cuh",
|
INCLUDE_PREFIX "binary_ops.cuh",
|
||||||
INCLUDE_PREFIX "cast_op.cuh",
|
INCLUDE_PREFIX "cast_op.cuh",
|
||||||
INCLUDE_PREFIX "config.h",
|
INCLUDE_PREFIX "config.h",
|
||||||
INCLUDE_PREFIX "cucomplex_math.cuh",
|
INCLUDE_PREFIX "complex.cuh",
|
||||||
INCLUDE_PREFIX "fp16_math.cuh",
|
INCLUDE_PREFIX "fp16_math.cuh",
|
||||||
INCLUDE_PREFIX "indexing.cuh",
|
INCLUDE_PREFIX "indexing.cuh",
|
||||||
INCLUDE_PREFIX "scatter_ops.cuh",
|
INCLUDE_PREFIX "scatter_ops.cuh",
|
||||||
@ -189,7 +189,7 @@ constexpr const char* g_headers[] = {
|
|||||||
jit_source_binary_ops,
|
jit_source_binary_ops,
|
||||||
jit_source_cast_op,
|
jit_source_cast_op,
|
||||||
jit_source_config,
|
jit_source_config,
|
||||||
jit_source_cucomplex_math,
|
jit_source_complex,
|
||||||
jit_source_fp16_math,
|
jit_source_fp16_math,
|
||||||
jit_source_indexing,
|
jit_source_indexing,
|
||||||
jit_source_scatter_ops,
|
jit_source_scatter_ops,
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
#include <cuComplex.h>
|
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
@ -79,7 +78,7 @@ struct CTypeToCudaType<bfloat16_t> {
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct CTypeToCudaType<complex64_t> {
|
struct CTypeToCudaType<complex64_t> {
|
||||||
using type = cuComplex;
|
using type = cu::complex64_t;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -91,10 +90,14 @@ inline constexpr bool is_floating_v =
|
|||||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
||||||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
||||||
|
|
||||||
|
// Type traits for detecting complex numbers.
|
||||||
|
template <typename T>
|
||||||
|
inline constexpr bool is_complex_v = cuda::std::is_same_v<T, complex64_t> ||
|
||||||
|
cuda::std::is_same_v<T, complex128_t>;
|
||||||
|
|
||||||
// Type traits for detecting complex or real floating point numbers.
|
// Type traits for detecting complex or real floating point numbers.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline constexpr bool is_inexact_v =
|
inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
|
||||||
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
|
|
||||||
|
|
||||||
// Utility to copy data from vector to array in host.
|
// Utility to copy data from vector to array in host.
|
||||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
#include "mlx/backend/common/reduce.h"
|
#include "mlx/backend/common/reduce.h"
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
|
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
|
@ -151,7 +151,7 @@ struct ReduceInit<Or, T> {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
struct ReduceInit<Sum, T> {
|
struct ReduceInit<Sum, T> {
|
||||||
static constexpr __host__ __device__ auto value() {
|
static constexpr __host__ __device__ auto value() {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
return T{0, 0};
|
return T{0, 0};
|
||||||
} else {
|
} else {
|
||||||
return cast_to<typename ReduceResult<Sum, T>::type>(0);
|
return cast_to<typename ReduceResult<Sum, T>::type>(0);
|
||||||
@ -162,7 +162,7 @@ struct ReduceInit<Sum, T> {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
struct ReduceInit<Prod, T> {
|
struct ReduceInit<Prod, T> {
|
||||||
static constexpr __host__ __device__ auto value() {
|
static constexpr __host__ __device__ auto value() {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
if constexpr (is_complex_v<T>) {
|
||||||
return T{1, 0};
|
return T{1, 0};
|
||||||
} else {
|
} else {
|
||||||
return cast_to<typename ReduceResult<Prod, T>::type>(1);
|
return cast_to<typename ReduceResult<Prod, T>::type>(1);
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
#include "mlx/backend/common/unary.h"
|
#include "mlx/backend/common/unary.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
|
||||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
|
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
@ -71,10 +70,10 @@ constexpr bool supports_unary_op() {
|
|||||||
!std::is_same_v<In, bool>;
|
!std::is_same_v<In, bool>;
|
||||||
}
|
}
|
||||||
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
|
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
|
||||||
return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
|
return std::is_same_v<In, Out> && !mlx::core::is_complex_v<In>;
|
||||||
}
|
}
|
||||||
if (std::is_same_v<Op, Conjugate>) {
|
if (std::is_same_v<Op, Conjugate>) {
|
||||||
return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
|
return std::is_same_v<In, Out> && mlx::core::is_complex_v<In>;
|
||||||
}
|
}
|
||||||
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
|
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
|
||||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
|
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
|
||||||
@ -88,7 +87,7 @@ constexpr bool supports_unary_op() {
|
|||||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||||
}
|
}
|
||||||
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
|
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
|
||||||
return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
|
return mlx::core::is_complex_v<In> && std::is_same_v<Out, float>;
|
||||||
}
|
}
|
||||||
if (std::is_same_v<Op, LogicalNot>) {
|
if (std::is_same_v<Op, LogicalNot>) {
|
||||||
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
|
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
|
||||||
|
@ -61,7 +61,7 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
|
|||||||
case float64:
|
case float64:
|
||||||
return "double";
|
return "double";
|
||||||
case complex64:
|
case complex64:
|
||||||
return "cuComplex";
|
return "complex64_t";
|
||||||
default:
|
default:
|
||||||
return "unknown";
|
return "unknown";
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user