MLX
Loading...
Searching...
No Matches
complex.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <metal_stdlib>
6
7using namespace metal;
8
// Forward declaration: the conversion traits below must be able to name the
// type before its definition.
struct complex64_t;

// Enables complex64_t's "real-only" converting constructors: any T, other
// than complex64_t itself, that is implicitly convertible to float.
template <typename T>
static constexpr constant bool can_convert_to_complex64 =
    !is_same_v<T, complex64_t> &&
    is_convertible_v<T, float>;

// Enables complex64_t's conversion operators: any T, other than complex64_t
// itself, that can be implicitly produced from either a float or a
// bfloat16_t.
// NOTE(review): bfloat16_t is not declared in this header (only
// <metal_stdlib> is included); presumably an MLX bf16 header must be included
// first — confirm the include order.
template <typename T>
static constexpr constant bool can_convert_from_complex64 =
    !is_same_v<T, complex64_t> &&
    (is_convertible_v<float, T> ||
     is_convertible_v<bfloat16_t, T>);
19
struct complex64_t {
  // Rectangular components: the represented value is real + imag * i.
  float real;
  float imag;

  // ---- Component / zero construction -------------------------------------

  // Builds the value re + im * i.
  constexpr complex64_t(float re, float im) : real(re), imag(im) {}

  // Zero value, for default (thread) and threadgroup address spaces.
  constexpr complex64_t() : real(0), imag(0) {}
  constexpr complex64_t() threadgroup : real(0), imag(0) {}

  // ---- Scalar <-> complex, one pair per Metal address-space qualifier ----
  // Converting constructors map a scalar x to (x, 0).
  // Conversion operators return only the real part; imag is discarded.

  // thread
  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) thread : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const thread {
    return static_cast<T>(real);
  }

  // threadgroup
  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const threadgroup {
    return static_cast<T>(real);
  }

  // device
  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) device : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const device {
    return static_cast<T>(real);
  }

  // constant
  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) constant : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const constant {
    return static_cast<T>(real);
  }
};
79
// Unary negation: flips the sign of both components.
constexpr complex64_t operator-(complex64_t x) {
  return complex64_t(-x.real, -x.imag);
}
83
// Lexicographic ordering: real parts decide unless they compare equal, in
// which case the imaginary parts decide.
constexpr bool operator>=(complex64_t a, complex64_t b) {
  return (a.real != b.real) ? (a.real > b.real) : (a.imag >= b.imag);
}
87
// Strict lexicographic ordering: real parts first, imaginary parts as the
// tie-breaker.
constexpr bool operator>(complex64_t a, complex64_t b) {
  return (a.real != b.real) ? (a.real > b.real) : (a.imag > b.imag);
}
91
// a <= b is the mirror image of b >= a under the lexicographic ordering.
constexpr bool operator<=(complex64_t a, complex64_t b) {
  return b >= a;
}
95
// a < b is the mirror image of b > a under the lexicographic ordering.
constexpr bool operator<(complex64_t a, complex64_t b) {
  return b > a;
}
99
// Equal exactly when neither component differs (IEEE float comparison, so
// any NaN component makes the values unequal).
constexpr bool operator==(complex64_t a, complex64_t b) {
  return !(a.real != b.real || a.imag != b.imag);
}
103
// Component-wise sum.
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
  return complex64_t(a.real + b.real, a.imag + b.imag);
}
107
// Component-wise difference.
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
  return complex64_t(a.real - b.real, a.imag - b.imag);
}
111
// Complex product:
//   (ar + ai*i)(br + bi*i) = (ar*br - ai*bi) + (ar*bi + ai*br)*i
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
  // Each component stays a single expression so contraction is unaffected.
  const float re = a.real * b.real - a.imag * b.imag;
  const float im = a.real * b.imag + a.imag * b.real;
  return complex64_t(re, im);
}
115
// Complex quotient via a / b = a * conj(b) / |b|^2.
// NOTE(review): |b|^2 is formed directly in float, so it overflows to inf
// once a component of b exceeds about 1.8e19 in magnitude, and underflows for
// very small b; a scaled (Smith-style) division would avoid this if needed.
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
  const float norm_sq = b.real * b.real + b.imag * b.imag;
  const float num_re = a.real * b.real + a.imag * b.imag;
  const float num_im = a.imag * b.real - a.real * b.imag;
  return complex64_t(num_re / norm_sq, num_im / norm_sq);
}
122
// Component-wise remainder whose sign follows the divisor (floored modulo).
// Each component is first reduced as r = a - b * q, where q is the quotient
// a / b truncated toward zero by converting it to int64_t. A non-zero r whose
// sign differs from b's is then shifted by b.
// NOTE(review): when a divisor component is 0, or the quotient is NaN/inf or
// outside int64_t's range, this float -> int64_t conversion is undefined in
// standard C++, so the result depends on the target's conversion behaviour.
// Confirm the intended semantics before changing the truncation.
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
  float re = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
  float im = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));

  const bool re_needs_shift = re != 0 && ((re < 0) != (b.real < 0));
  const bool im_needs_shift = im != 0 && ((im < 0) != (b.imag < 0));

  re = re_needs_shift ? re + b.real : re;
  im = im_needs_shift ? im + b.imag : im;
  return complex64_t(re, im);
}
constexpr bool operator>(complex64_t a, complex64_t b)
Definition complex.h:88
constexpr complex64_t operator-(complex64_t x)
Definition complex.h:80
static constexpr constant bool can_convert_to_complex64
Definition complex.h:12
constexpr bool operator<(complex64_t a, complex64_t b)
Definition complex.h:96
constexpr complex64_t operator*(complex64_t a, complex64_t b)
Definition complex.h:112
constexpr complex64_t operator%(complex64_t a, complex64_t b)
Definition complex.h:123
constexpr bool operator>=(complex64_t a, complex64_t b)
Definition complex.h:84
static constexpr constant bool can_convert_from_complex64
Definition complex.h:16
constexpr bool operator==(complex64_t a, complex64_t b)
Definition complex.h:100
constexpr complex64_t operator+(complex64_t a, complex64_t b)
Definition complex.h:104
constexpr complex64_t operator/(complex64_t a, complex64_t b)
Definition complex.h:116
constexpr bool operator<=(complex64_t a, complex64_t b)
Definition complex.h:92
Definition bf16_math.h:226
complex64_t
Definition complex.h:20
constexpr complex64_t(T x) constant
Definition complex.h:48
constexpr complex64_t()
Definition complex.h:26
constexpr complex64_t(T x) thread
Definition complex.h:33
constexpr complex64_t(T x) threadgroup
Definition complex.h:38
constexpr complex64_t() threadgroup
Definition complex.h:27
float imag
Definition complex.h:22
float real
Definition complex.h:21
constexpr complex64_t(T x) device
Definition complex.h:43
constexpr complex64_t(float real, float imag)
Definition complex.h:25