mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
78 lines
2.8 KiB
C++
78 lines
2.8 KiB
C++
// Copyright © 2023 Apple Inc.
|
|
|
|
#pragma once
|
|
#include <complex>
#include <type_traits>

#include "mlx/types/half_types.h"
|
|
|
|
namespace mlx::core {
|
|
|
|
struct complex64_t;

// True when a value of type T may implicitly build a complex64_t through the
// scalar converting constructor below: T must be convertible to float and must
// not be complex64_t itself, so that constructor template never competes with
// the copy/move constructors.
// `inline constexpr` (C++17) yields a single definition shared by every
// translation unit including this header, rather than one internal-linkage
// copy per TU as `static` did.
template <typename T>
inline constexpr bool can_convert_to_complex64 =
    !std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;

// Single-precision complex number. Publicly derives from std::complex<float>,
// so its members (real(), imag(), ...) are available directly.
struct complex64_t : public std::complex<float> {
  // Zero value (0 + 0i); allows `complex64_t z;` and value-initialisation.
  complex64_t() : std::complex<float>() {}

  // Builds v + u*i.
  complex64_t(float v, float u) : std::complex<float>(v, u) {}

  // Implicit conversion from the standard-library complex type.
  complex64_t(std::complex<float> v) : std::complex<float>(v) {}

  // Implicit conversion from any T with can_convert_to_complex64<T>: x becomes
  // the real part and the imaginary part is zero (std::complex's default).
  template <
      typename T,
      typename = typename std::enable_if<can_convert_to_complex64<T>>::type>
  complex64_t(T x) : std::complex<float>(x) {}

  // Implicit conversion to float. NOTE: returns real() only; the imaginary
  // part is silently discarded.
  operator float() const {
    return real();
  }
};
|
|
|
|
// Lexicographic "greater than or equal": when the real parts differ they alone
// decide; otherwise the imaginary parts are compared.
inline bool operator>=(const complex64_t& a, const complex64_t& b) {
  if (a.real() != b.real()) {
    return a.real() > b.real();
  }
  return a.imag() >= b.imag();
}
|
|
|
|
// Strict lexicographic ordering: the real parts are compared first, and the
// imaginary parts only break a tie between equal real parts.
inline bool operator>(const complex64_t& a, const complex64_t& b) {
  if (a.real() > b.real()) {
    return true;
  }
  const bool real_parts_tie = a.real() == b.real();
  return real_parts_tie && a.imag() > b.imag();
}
|
|
|
|
// Mirror of operator>=: a <= b holds exactly when b >= a.
inline bool operator<=(const complex64_t& a, const complex64_t& b) {
  return b >= a;
}
|
|
|
|
// Mirror of operator>: a < b holds exactly when b > a.
inline bool operator<(const complex64_t& a, const complex64_t& b) {
  return b > a;
}
|
|
|
|
// Unary negation, delegated to std::complex<float>'s unary minus and wrapped
// back into a complex64_t.
inline complex64_t operator-(const complex64_t& v) {
  const std::complex<float>& as_std = v;
  return complex64_t(-as_std);
}
|
|
|
|
// clang-format off
// Defines `_operator_` (spelled with the C++ operator `_op_`) between a
// complex64_t and a real scalar of type `itype`, in both operand orders.
// The scalar is converted to float first so the expression matches
// std::complex<float>'s own scalar/complex operator templates, which deduce
// the scalar type and therefore only accept exactly `float`. Without that
// conversion, e.g. `int32_t + std::complex<float>` matches none of the std
// templates, and overload resolution selects this very function again
// (converting the std::complex<float> back to complex64_t), recursing forever.
// `itype` must be convertible to float via static_cast.
#define complex_binop_helper(_op_, _operator_, itype)                      \
  inline complex64_t _operator_(itype x, const complex64_t& y) {           \
    return static_cast<float>(x) _op_ static_cast<std::complex<float>>(y); \
  }                                                                        \
  inline complex64_t _operator_(const complex64_t& x, itype y) {           \
    return static_cast<std::complex<float>>(x) _op_ static_cast<float>(y); \
  }

// Defines `_operator_` for complex64_t combined with complex64_t,
// std::complex<float>, and each supported real scalar type (both orders).
// The arithmetic itself is performed by std::complex<float>; the result is
// converted back to complex64_t.
#define complex_binop(_op_, _operator_)                                                  \
  inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) {            \
    return static_cast<std::complex<float>>(x)                                           \
        _op_ static_cast<std::complex<float>>(y);                                        \
  }                                                                                      \
  inline complex64_t _operator_(const std::complex<float>& x, const complex64_t& y) {    \
    return x _op_ static_cast<std::complex<float>>(y);                                   \
  }                                                                                      \
  inline complex64_t _operator_(const complex64_t& x, const std::complex<float>& y) {    \
    return static_cast<std::complex<float>>(x) _op_ y;                                   \
  }                                                                                      \
  complex_binop_helper(_op_, _operator_, bool)                                           \
  complex_binop_helper(_op_, _operator_, uint32_t)                                       \
  complex_binop_helper(_op_, _operator_, uint64_t)                                       \
  complex_binop_helper(_op_, _operator_, int32_t)                                        \
  complex_binop_helper(_op_, _operator_, int64_t)                                        \
  complex_binop_helper(_op_, _operator_, float16_t)                                      \
  complex_binop_helper(_op_, _operator_, bfloat16_t)                                     \
  complex_binop_helper(_op_, _operator_, float)
// clang-format on
|
|
|
|
// Defines operator+ for complex64_t with every operand combination listed in
// the complex_binop macro (complex64_t with itself and with each listed type,
// in both orders).
complex_binop(+, operator+)
|
|
|
|
} // namespace mlx::core
|