Files
mlx/mlx/types/complex.h
2023-11-30 11:12:53 -08:00

78 lines
2.8 KiB
C++

// Copyright © 2023 Apple Inc.
#pragma once
#include <complex>
#include <cstdint>
#include <type_traits>

#include "mlx/types/half_types.h"
namespace mlx::core {
struct complex64_t;

/// True when `T` may be passed to complex64_t's generic single-argument
/// constructor: `T` must be implicitly convertible to float, and must not be
/// complex64_t itself. complex64_t is excluded so that copying a complex64_t
/// is always handled by the copy/move constructors, never by the generic one
/// (complex64_t is convertible to float through its `operator float`).
/// `inline constexpr` (C++17) gives one entity shared by every translation
/// unit that includes this header, instead of a `static` copy per unit.
template <typename T>
inline constexpr bool can_convert_to_complex64 =
    !std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;

/// Complex number whose real and imaginary parts are stored as float.
///
/// A thin wrapper over std::complex<float>: accessors and arithmetic come from
/// the base class. This type only adds construction from real scalars and an
/// implicit conversion back to float.
struct complex64_t : public std::complex<float> {
  /// Builds the value `v + u*i`.
  complex64_t(float v, float u) : std::complex<float>(v, u) {}

  /// Wraps an existing std::complex<float> value.
  complex64_t(std::complex<float> v) : std::complex<float>(v) {}

  /// Builds a complex number with real part `x` (converted to float) and a
  /// zero imaginary part, from any type implicitly convertible to float
  /// (bool, integers, float, ...).
  template <
      typename T,
      typename = std::enable_if_t<can_convert_to_complex64<T>>>
  complex64_t(T x) : std::complex<float>(x) {}

  /// Returns the real part.
  /// NOTE: this conversion is implicit and silently discards the imaginary
  /// part.
  operator float() const {
    return real();
  }
};
/// Lexicographic `a >= b` on (real, imag): the real parts decide unless they
/// compare equal, in which case the imaginary parts break the tie.
inline bool operator>=(const complex64_t& a, const complex64_t& b) {
  if (a.real() != b.real()) {
    return a.real() > b.real();
  }
  return a.imag() >= b.imag();
}
/// Lexicographic `a > b` on (real, imag): the real parts decide unless they
/// compare equal, in which case the imaginary parts break the tie.
inline bool operator>(const complex64_t& a, const complex64_t& b) {
  const bool same_real = a.real() == b.real();
  return same_real ? a.imag() > b.imag() : a.real() > b.real();
}
/// Lexicographic `a <= b` on (real, imag); the mirror image of `b >= a`.
inline bool operator<=(const complex64_t& a, const complex64_t& b) {
  if (a.real() != b.real()) {
    return a.real() < b.real();
  }
  return a.imag() <= b.imag();
}
/// Lexicographic `a < b` on (real, imag); the mirror image of `b > a`.
inline bool operator<(const complex64_t& a, const complex64_t& b) {
  const bool same_real = a.real() == b.real();
  return same_real ? a.imag() < b.imag() : a.real() < b.real();
}
/// Unary negation: flips the sign of both the real and imaginary parts.
inline complex64_t operator-(const complex64_t& v) {
  const float neg_re = -v.real();
  const float neg_im = -v.imag();
  return complex64_t(neg_re, neg_im);
}
// Generators for binary-operator overloads involving complex64_t.
//
// Every overload forwards to the standard library's std::complex<float>
// operators and converts the result back to complex64_t.
//
// The operand handed to a std::complex<float> operator must already have a
// type those operators accept (float or std::complex<float>). If any other
// scalar (bool, integers, half types) were passed through unconverted, no
// std::complex operator template could deduce its `T`. Overload resolution
// would then pick the very overload being defined (exact match on the scalar
// argument), which recurses forever. Real scalars are therefore cast to float
// first.
// clang-format off

// Defines `itype op complex64_t` and `complex64_t op itype` for an operand
// type that the std::complex<float> operators accept directly
// (std::complex<float> or float). Do not use it for other scalar types; use
// complex_scalar_binop_helper instead.
#define complex_binop_helper(_op_, _operator_, itype)                \
  inline complex64_t _operator_(itype x, const complex64_t& y) {     \
    return x _op_ static_cast<std::complex<float>>(y);               \
  }                                                                  \
  inline complex64_t _operator_(const complex64_t& x, itype y) {     \
    return static_cast<std::complex<float>>(x) _op_ y;               \
  }

// Defines `itype op complex64_t` and `complex64_t op itype` for a real scalar
// type convertible to float. Converting the scalar to float selects the
// std `float op complex<float>` / `complex<float> op float` overloads, so the
// scalar acts as a real number (it never re-enters these overloads).
#define complex_scalar_binop_helper(_op_, _operator_, itype)                  \
  inline complex64_t _operator_(itype x, const complex64_t& y) {              \
    return static_cast<float>(x) _op_ static_cast<std::complex<float>>(y);    \
  }                                                                           \
  inline complex64_t _operator_(const complex64_t& x, itype y) {              \
    return static_cast<std::complex<float>>(x) _op_ static_cast<float>(y);    \
  }

// Defines every overload of `_operator_` involving complex64_t:
// complex64_t op complex64_t, plus both operand orders for each supported
// real scalar type, std::complex<float>, and float.
#define complex_binop(_op_, _operator_)                                       \
  inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \
    return static_cast<std::complex<float>>(x)                                \
        _op_ static_cast<std::complex<float>>(y);                             \
  }                                                                           \
  complex_scalar_binop_helper(_op_, _operator_, bool)                         \
  complex_scalar_binop_helper(_op_, _operator_, uint32_t)                     \
  complex_scalar_binop_helper(_op_, _operator_, uint64_t)                     \
  complex_scalar_binop_helper(_op_, _operator_, int32_t)                      \
  complex_scalar_binop_helper(_op_, _operator_, int64_t)                      \
  complex_scalar_binop_helper(_op_, _operator_, float16_t)                    \
  complex_scalar_binop_helper(_op_, _operator_, bfloat16_t)                   \
  complex_binop_helper(_op_, _operator_, const std::complex<float>&)          \
  complex_binop_helper(_op_, _operator_, float)
// clang-format on

complex_binop(+, operator+)
} // namespace mlx::core