MLX
Loading...
Searching...
No Matches
complex.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <metal_stdlib>
6
7using namespace metal;
8
9struct complex64_t;
10
11template <typename T>
12static constexpr constant bool can_convert_to_complex64 =
13 !is_same_v<T, complex64_t> && is_convertible_v<T, float>;
14
15template <typename T>
16static constexpr constant bool can_convert_from_complex64 =
17 !is_same_v<T, complex64_t> &&
18 (is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
19
21 float real;
22 float imag;
23
24 // Constructors
25 constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
26
27 // Conversions to complex64_t
28 template <
29 typename T,
30 typename = typename enable_if<can_convert_to_complex64<T>>::type>
31 constexpr complex64_t(T x) thread : real(x), imag(0) {}
32
33 template <
34 typename T,
35 typename = typename enable_if<can_convert_to_complex64<T>>::type>
36 constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
37
38 template <
39 typename T,
40 typename = typename enable_if<can_convert_to_complex64<T>>::type>
41 constexpr complex64_t(T x) device : real(x), imag(0) {}
42
43 template <
44 typename T,
45 typename = typename enable_if<can_convert_to_complex64<T>>::type>
46 constexpr complex64_t(T x) constant : real(x), imag(0) {}
47
48 // Conversions from complex64_t
49 template <
50 typename T,
51 typename = typename enable_if<can_convert_from_complex64<T>>::type>
52 constexpr operator T() const thread {
53 return static_cast<T>(real);
54 }
55
56 template <
57 typename T,
58 typename = typename enable_if<can_convert_from_complex64<T>>::type>
59 constexpr operator T() const threadgroup {
60 return static_cast<T>(real);
61 }
62
63 template <
64 typename T,
65 typename = typename enable_if<can_convert_from_complex64<T>>::type>
66 constexpr operator T() const device {
67 return static_cast<T>(real);
68 }
69
70 template <
71 typename T,
72 typename = typename enable_if<can_convert_from_complex64<T>>::type>
73 constexpr operator T() const constant {
74 return static_cast<T>(real);
75 }
76};
77
79 return {-x.real, -x.imag};
80}
81
82constexpr bool operator>=(complex64_t a, complex64_t b) {
83 return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
84}
85
86constexpr bool operator>(complex64_t a, complex64_t b) {
87 return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
88}
89
90constexpr bool operator<=(complex64_t a, complex64_t b) {
91 return operator>=(b, a);
92}
93
94constexpr bool operator<(complex64_t a, complex64_t b) {
95 return operator>(b, a);
96}
97
98constexpr bool operator==(complex64_t a, complex64_t b) {
99 return a.real == b.real && a.imag == b.imag;
100}
101
103 return {a.real + b.real, a.imag + b.imag};
104}
105
107 return {a.real - b.real, a.imag - b.imag};
108}
109
111 return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
112}
113
115 auto denom = b.real * b.real + b.imag * b.imag;
116 auto x = a.real * b.real + a.imag * b.imag;
117 auto y = a.imag * b.real - a.real * b.imag;
118 return {x / denom, y / denom};
119}
120
122 auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
123 auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
124 if (real != 0 && (real < 0 != b.real < 0)) {
125 real += b.real;
126 }
127 if (imag != 0 && (imag < 0 != b.imag < 0)) {
128 imag += b.imag;
129 }
130 return {real, imag};
131}
constexpr bool operator>(complex64_t a, complex64_t b)
Definition complex.h:86
constexpr complex64_t operator-(complex64_t x)
Definition complex.h:78
static constexpr constant bool can_convert_to_complex64
Definition complex.h:12
constexpr bool operator<(complex64_t a, complex64_t b)
Definition complex.h:94
constexpr complex64_t operator*(complex64_t a, complex64_t b)
Definition complex.h:110
constexpr complex64_t operator%(complex64_t a, complex64_t b)
Definition complex.h:121
constexpr bool operator>=(complex64_t a, complex64_t b)
Definition complex.h:82
static constexpr constant bool can_convert_from_complex64
Definition complex.h:16
constexpr bool operator==(complex64_t a, complex64_t b)
Definition complex.h:98
constexpr complex64_t operator+(complex64_t a, complex64_t b)
Definition complex.h:102
constexpr complex64_t operator/(complex64_t a, complex64_t b)
Definition complex.h:114
constexpr bool operator<=(complex64_t a, complex64_t b)
Definition complex.h:90
Definition bf16.h:265
Definition complex.h:20
constexpr complex64_t(T x) constant
Definition complex.h:46
constexpr complex64_t(T x) thread
Definition complex.h:31
constexpr complex64_t(T x) threadgroup
Definition complex.h:36
float imag
Definition complex.h:22
float real
Definition complex.h:21
constexpr complex64_t(T x) device
Definition complex.h:41
constexpr complex64_t(float real, float imag)
Definition complex.h:25