MLX
Loading...
Searching...
No Matches
binary.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <metal_integer>
6#include <metal_math>
7
10
// Element-wise addition functor: yields `x + y`, converted back to T.
struct Add {
  template <typename T>
  T operator()(T x, T y) {
    const T sum = x + y;
    return sum;
  }
};
17
// Element-wise division functor: yields `x / y` using T's own division
// (so integer operands use integer division).
struct Divide {
  template <typename T>
  T operator()(T x, T y) {
    const T quotient = x / y;
    return quotient;
  }
};
24
25struct Remainder {
26 template <typename T>
27 metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
28 operator()(T x, T y) {
29 return x % y;
30 }
31 template <typename T>
32 metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
33 operator()(T x, T y) {
34 auto r = x % y;
35 if (r != 0 && (r < 0 != y < 0)) {
36 r += y;
37 }
38 return r;
39 }
40 template <typename T>
41 metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
42 T r = fmod(x, y);
43 if (r != 0 && (r < 0 != y < 0)) {
44 r += y;
45 }
46 return r;
47 }
48 template <>
50 return x % y;
51 }
52};
53
// Element-wise equality functor using T's operator== (so NaN != NaN for
// IEEE floating-point types).
struct Equal {
  template <typename T>
  bool operator()(T x, T y) {
    const bool same = x == y;
    return same;
  }
};
60
// Equality functor that additionally reports NaN == NaN as true (ordinary ==
// reports false when an operand is NaN).
61struct NaNEqual {
// Generic case: equal when x == y, or when both x and y are NaN.
62 template <typename T>
63 bool operator()(T x, T y) {
64 return x == y || (metal::isnan(x) && metal::isnan(y));
65 }
// complex64_t case: equal when x == y, or when the components match pairwise
// with NaN treated as equal to NaN (e.g. equal real parts and NaN imaginary
// parts in both operands).
// NOTE(review): source lines 67 and 69 are missing from this listing. The
// Doxygen cross-reference gives line 67 as
// `bool operator()(complex64_t x, complex64_t y)`; line 69 presumably opens
// the "every component of x and y is NaN" alternative that line 70 closes —
// confirm against the original header before relying on this text.
66 template <>
68 return x == y ||
70 metal::isnan(y.imag)) ||
71 (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
72 (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
73 }
74};
75
// Element-wise strict greater-than comparison, `x > y`.
struct Greater {
  template <typename T>
  bool operator()(T x, T y) {
    const bool above = x > y;
    return above;
  }
};
82
// NOTE(review): this struct's opening line (source line 83) is missing from
// the listing. Given the `>=` body and its position between `Greater` and
// `Less`, it is presumably `struct GreaterEqual {` — confirm.
// Element-wise comparison: true when x >= y.
84 template <typename T>
85 bool operator()(T x, T y) {
86 return x >= y;
87 }
88};
89
// Element-wise strict less-than comparison, `x < y`.
struct Less {
  template <typename T>
  bool operator()(T x, T y) {
    const bool below = x < y;
    return below;
  }
};
96
// Element-wise less-than-or-equal comparison, `x <= y`.
struct LessEqual {
  template <typename T>
  bool operator()(T x, T y) {
    const bool not_above = x <= y;
    return not_above;
  }
};
103
104struct LogAddExp {
105 template <typename T>
106 T operator()(T x, T y) {
107 if (metal::isnan(x) || metal::isnan(y)) {
108 return metal::numeric_limits<T>::quiet_NaN();
109 }
110 constexpr T inf = metal::numeric_limits<T>::infinity();
111 T maxval = metal::max(x, y);
112 T minval = metal::min(x, y);
113 return (minval == -inf || maxval == inf)
114 ? maxval
115 : (maxval + log1p(metal::exp(minval - maxval)));
116 };
117};
118
119struct Maximum {
120 template <typename T>
121 metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
122 return metal::max(x, y);
123 }
124
125 template <typename T>
126 metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
127 if (metal::isnan(x)) {
128 return x;
129 }
130 return x > y ? x : y;
131 }
132
133 template <>
135 if (metal::isnan(x.real) || metal::isnan(x.imag)) {
136 return x;
137 }
138 return x > y ? x : y;
139 }
140};
141
142struct Minimum {
143 template <typename T>
144 metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
145 return metal::min(x, y);
146 }
147
148 template <typename T>
149 metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
150 if (metal::isnan(x)) {
151 return x;
152 }
153 return x < y ? x : y;
154 }
155
156 template <>
158 if (metal::isnan(x.real) || metal::isnan(x.imag)) {
159 return x;
160 }
161 return x < y ? x : y;
162 }
163};
164
// Element-wise multiplication functor: yields `x * y`, converted back to T.
struct Multiply {
  template <typename T>
  T operator()(T x, T y) {
    const T product = x * y;
    return product;
  }
};
171
172struct NotEqual {
173 template <typename T>
174 bool operator()(T x, T y) {
175 return x != y;
176 }
177 template <>
179 return x.real != y.real || x.imag != y.imag;
180 }
181};
182
183struct Power {
184 template <typename T>
185 metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
186 return metal::pow(base, exp);
187 }
188
189 template <typename T>
190 metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
191 T res = 1;
192 while (exp) {
193 if (exp & 1) {
194 res *= base;
195 }
196 exp >>= 1;
197 base *= base;
198 }
199 return res;
200 }
201
202 template <>
204 auto x_theta = metal::atan(x.imag / x.real);
205 auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
206 auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
207 auto phase = y.imag * x_ln_r + y.real * x_theta;
208 return {mag * metal::cos(phase), mag * metal::sin(phase)};
209 }
210};
211
// Element-wise subtraction functor: yields `x - y`, converted back to T.
struct Subtract {
  template <typename T>
  T operator()(T x, T y) {
    const T difference = x - y;
    return difference;
  }
};
218
// NOTE(review): this struct's opening line (source line 219) is missing from
// the listing; by analogy with LogicalOr below it is presumably
// `struct LogicalAnd {` — confirm.
// Element-wise logical AND: returns `x && y` converted back to T (the return
// type is T, not bool). The `;` after the member's closing brace is a stray
// empty member declaration and is harmless.
220 template <typename T>
221 T operator()(T x, T y) {
222 return x && y;
223 };
224};
225
// Element-wise logical OR. The boolean result of `x || y` is converted back
// to T, so the return type matches the operands rather than being bool.
struct LogicalOr {
  template <typename T>
  T operator()(T x, T y) {
    const T either = x || y;
    return either;
  }
};
232
// NOTE(review): this struct's opening line (source line 233) is missing from
// the listing; given the `&` body it is presumably `struct BitwiseAnd {` —
// confirm.
// Element-wise bitwise AND, `x & y`, returned as T. The `;` after the
// member's closing brace is a stray empty member declaration and is harmless.
234 template <typename T>
235 T operator()(T x, T y) {
236 return x & y;
237 };
238};
239
// Element-wise bitwise OR, `x | y`, returned as T.
struct BitwiseOr {
  template <typename T>
  T operator()(T x, T y) {
    const T merged = x | y;
    return merged;
  }
};
246
// NOTE(review): this struct's opening line (source line 247) is missing from
// the listing; given the `^` body it is presumably `struct BitwiseXor {` —
// confirm.
// Element-wise bitwise exclusive OR, `x ^ y`, returned as T. The `;` after
// the member's closing brace is a stray empty member declaration and is
// harmless.
248 template <typename T>
249 T operator()(T x, T y) {
250 return x ^ y;
251 };
252};
253
// Element-wise left shift: x shifted left by y bit positions, returned as T.
struct LeftShift {
  template <typename T>
  T operator()(T x, T y) {
    const T shifted = x << y;
    return shifted;
  }
};
260
// NOTE(review): this struct's opening line (source line 261) is missing from
// the listing; as the counterpart of LeftShift above it is presumably
// `struct RightShift {` — confirm.
// Element-wise right shift: x shifted right by y bit positions, returned as
// T. The `;` after the member's closing brace is a stray empty member
// declaration and is harmless.
262 template <typename T>
263 T operator()(T x, T y) {
264 return x >> y;
265 };
266};
267
268struct ArcTan2 {
269 template <typename T>
270 T operator()(T y, T x) {
271 return metal::precise::atan2(y, x);
272 }
273};
float log1p(float x)
Definition utils.h:298
METAL_FUNC bfloat16_t atan2(bfloat16_t y, bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t cos(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t fmod(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t sin(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t atan(bfloat16_t y_over_x)
Definition bf16_math.h:234
METAL_FUNC bool isnan(_MLX_BFloat16 x)
Definition bf16.h:307
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t pow(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
Definition binary.h:11
T operator()(T x, T y)
Definition binary.h:13
Definition binary.h:268
T operator()(T y, T x)
Definition binary.h:270
Definition binary.h:233
T operator()(T x, T y)
Definition binary.h:235
Definition binary.h:240
T operator()(T x, T y)
Definition binary.h:242
Definition binary.h:247
T operator()(T x, T y)
Definition binary.h:249
Definition binary.h:18
T operator()(T x, T y)
Definition binary.h:20
Definition binary.h:54
bool operator()(T x, T y)
Definition binary.h:56
Definition binary.h:83
bool operator()(T x, T y)
Definition binary.h:85
Definition binary.h:76
bool operator()(T x, T y)
Definition binary.h:78
Definition binary.h:254
T operator()(T x, T y)
Definition binary.h:256
Definition binary.h:97
bool operator()(T x, T y)
Definition binary.h:99
Definition binary.h:90
bool operator()(T x, T y)
Definition binary.h:92
Definition binary.h:104
T operator()(T x, T y)
Definition binary.h:106
Definition binary.h:219
T operator()(T x, T y)
Definition binary.h:221
Definition binary.h:226
T operator()(T x, T y)
Definition binary.h:228
Definition binary.h:119
metal::enable_if_t<!metal::is_integral_v< T >, T > operator()(T x, T y)
Definition binary.h:126
metal::enable_if_t< metal::is_integral_v< T >, T > operator()(T x, T y)
Definition binary.h:121
complex64_t operator()(complex64_t x, complex64_t y)
Definition binary.h:134
Definition binary.h:142
metal::enable_if_t<!metal::is_integral_v< T >, T > operator()(T x, T y)
Definition binary.h:149
complex64_t operator()(complex64_t x, complex64_t y)
Definition binary.h:157
metal::enable_if_t< metal::is_integral_v< T >, T > operator()(T x, T y)
Definition binary.h:144
Definition binary.h:165
T operator()(T x, T y)
Definition binary.h:167
Definition binary.h:61
bool operator()(T x, T y)
Definition binary.h:63
bool operator()(complex64_t x, complex64_t y)
Definition binary.h:67
Definition binary.h:172
bool operator()(complex64_t x, complex64_t y)
Definition binary.h:178
bool operator()(T x, T y)
Definition binary.h:174
Definition binary.h:183
complex64_t operator()(complex64_t x, complex64_t y)
Definition binary.h:203
metal::enable_if_t<!metal::is_integral_v< T >, T > operator()(T base, T exp)
Definition binary.h:185
metal::enable_if_t< metal::is_integral_v< T >, T > operator()(T base, T exp)
Definition binary.h:190
Definition binary.h:25
metal::enable_if_t< metal::is_integral_v< T > &metal::is_signed_v< T >, T > operator()(T x, T y)
Definition binary.h:33
metal::enable_if_t<!metal::is_integral_v< T >, T > operator()(T x, T y)
Definition binary.h:41
metal::enable_if_t< metal::is_integral_v< T > &!metal::is_signed_v< T >, T > operator()(T x, T y)
Definition binary.h:28
complex64_t operator()(complex64_t x, complex64_t y)
Definition binary.h:49
Definition binary.h:261
T operator()(T x, T y)
Definition binary.h:263
Definition binary.h:212
T operator()(T x, T y)
Definition binary.h:214
Definition complex.h:20
float imag
Definition complex.h:22
float real
Definition complex.h:21