MLX
 
Loading...
Searching...
No Matches
binary_ops.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <metal_integer>
6#include <metal_math>
7
/// Elementwise sum, delegating to T's own `operator+`.
struct Add {
  template <typename T>
  T operator()(T x, T y) {
    auto sum = x + y;
    return sum;
  }
};
14
16 template <typename T>
17 T operator()(T x, T y) {
18 return x / y;
19 }
20 template <>
21 float operator()(float x, float y) {
22 return trunc(x / y);
23 }
24 template <>
25 half operator()(half x, half y) {
26 return trunc(x / y);
27 }
28 template <>
30 return trunc(x / y);
31 }
32};
33
/// Elementwise quotient using T's `operator/` (integer types truncate).
struct Divide {
  template <typename T>
  T operator()(T x, T y) {
    auto quotient = x / y;
    return quotient;
  }
};
40
41struct Remainder {
42 template <typename T>
43 metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
44 operator()(T x, T y) {
45 return x % y;
46 }
47 template <typename T>
48 metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
49 operator()(T x, T y) {
50 auto r = x % y;
51 if (r != 0 && (r < 0 != y < 0)) {
52 r += y;
53 }
54 return r;
55 }
56 template <typename T>
57 metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
58 T r = fmod(x, y);
59 if (r != 0 && (r < 0 != y < 0)) {
60 r += y;
61 }
62 return r;
63 }
64 template <>
66 return x % y;
67 }
68};
69
/// Elementwise equality test via T's `operator==`.
struct Equal {
  template <typename T>
  bool operator()(T x, T y) {
    auto same = x == y;
    return same;
  }
};
76
77struct NaNEqual {
78 template <typename T>
79 bool operator()(T x, T y) {
80 return x == y || (metal::isnan(x) && metal::isnan(y));
81 }
82 template <>
84 return x == y ||
86 metal::isnan(y.imag)) ||
87 (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
88 (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
89 }
90};
91
/// Elementwise `x > y`.
struct Greater {
  template <typename T>
  bool operator()(T x, T y) {
    auto result = x > y;
    return result;
  }
};
98
/// Elementwise `x >= y`.
struct GreaterEqual {
  template <typename T>
  bool operator()(T x, T y) {
    auto result = x >= y;
    return result;
  }
};
105
/// Elementwise `x < y`.
struct Less {
  template <typename T>
  bool operator()(T x, T y) {
    auto result = x < y;
    return result;
  }
};
112
/// Elementwise `x <= y`.
struct LessEqual {
  template <typename T>
  bool operator()(T x, T y) {
    auto result = x <= y;
    return result;
  }
};
119
120struct LogAddExp {
121 template <typename T>
122 T operator()(T x, T y) {
123 if (metal::isnan(x) || metal::isnan(y)) {
124 return metal::numeric_limits<T>::quiet_NaN();
125 }
126 constexpr T inf = metal::numeric_limits<T>::infinity();
127 T maxval = metal::max(x, y);
128 T minval = metal::min(x, y);
129 return (minval == -inf || maxval == inf)
130 ? maxval
131 : (maxval + log1p(metal::exp(minval - maxval)));
132 };
133};
134
135struct Maximum {
136 template <typename T>
137 metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
138 return metal::max(x, y);
139 }
140
141 template <typename T>
142 metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
143 if (metal::isnan(x)) {
144 return x;
145 }
146 return x > y ? x : y;
147 }
148
149 template <>
151 if (metal::isnan(x.real) || metal::isnan(x.imag)) {
152 return x;
153 }
154 return x > y ? x : y;
155 }
156};
157
158struct Minimum {
159 template <typename T>
160 metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
161 return metal::min(x, y);
162 }
163
164 template <typename T>
165 metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
166 if (metal::isnan(x)) {
167 return x;
168 }
169 return x < y ? x : y;
170 }
171
172 template <>
174 if (metal::isnan(x.real) || metal::isnan(x.imag)) {
175 return x;
176 }
177 return x < y ? x : y;
178 }
179};
180
/// Elementwise product via T's `operator*`.
struct Multiply {
  template <typename T>
  T operator()(T x, T y) {
    auto product = x * y;
    return product;
  }
};
187
188struct NotEqual {
189 template <typename T>
190 bool operator()(T x, T y) {
191 return x != y;
192 }
193 template <>
195 return x.real != y.real || x.imag != y.imag;
196 }
197};
198
199struct Power {
200 template <typename T>
201 metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
202 return metal::pow(base, exp);
203 }
204
205 template <typename T>
206 metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
207 T res = 1;
208 while (exp) {
209 if (exp & 1) {
210 res *= base;
211 }
212 exp >>= 1;
213 base *= base;
214 }
215 return res;
216 }
217
218 template <>
220 auto x_theta = metal::atan2(x.imag, x.real);
221 auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
222 auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
223 auto phase = y.imag * x_ln_r + y.real * x_theta;
224 return {mag * metal::cos(phase), mag * metal::sin(phase)};
225 }
226};
227
/// Elementwise difference `x - y`.
struct Subtract {
  template <typename T>
  T operator()(T x, T y) {
    auto difference = x - y;
    return difference;
  }
};
234
/// Logical AND; the boolean result is converted back to T.
struct LogicalAnd {
  template <typename T>
  T operator()(T x, T y) {
    auto both = x && y;
    return both;
  }
};
241
/// Logical OR; the boolean result is converted back to T.
struct LogicalOr {
  template <typename T>
  T operator()(T x, T y) {
    auto either = x || y;
    return either;
  }
};
248
/// Elementwise bitwise AND.
struct BitwiseAnd {
  template <typename T>
  T operator()(T x, T y) {
    auto bits = x & y;
    return bits;
  }
};
255
/// Elementwise bitwise OR.
struct BitwiseOr {
  template <typename T>
  T operator()(T x, T y) {
    auto bits = x | y;
    return bits;
  }
};
262
/// Elementwise bitwise exclusive OR.
struct BitwiseXor {
  template <typename T>
  T operator()(T x, T y) {
    auto bits = x ^ y;
    return bits;
  }
};
269
/// Elementwise left shift of `x` by `y` bits.
struct LeftShift {
  template <typename T>
  T operator()(T x, T y) {
    auto shifted = x << y;
    return shifted;
  }
};
276
/// Elementwise right shift of `x` by `y` bits.
struct RightShift {
  template <typename T>
  T operator()(T x, T y) {
    auto shifted = x >> y;
    return shifted;
  }
};
283
284struct ArcTan2 {
285 template <typename T>
286 T operator()(T y, T x) {
287 return metal::precise::atan2(y, x);
288 }
289};
290
291struct DivMod {
292 template <typename T>
293 metal::array<T, 2> operator()(T x, T y) {
294 return {FloorDivide{}(x, y), Remainder{}(x, y)};
295 };
296};
struct _MLX_BFloat16 bfloat16_t
Definition bf16.h:251
float log1p(float x)
Definition utils.h:307
METAL_FUNC bfloat16_t atan2(bfloat16_t y, bfloat16_t x)
Definition bf16_math.h:250
METAL_FUNC bfloat16_t atan2(bfloat16_t y, bfloat16_t x)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t cos(bfloat16_t x)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t sin(bfloat16_t x)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
METAL_FUNC bool isnan(_MLX_BFloat16 x)
Definition bf16.h:301
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t pow(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
Definition binary_ops.h:8
T operator()(T x, T y)
Definition binary_ops.h:10
Definition binary_ops.h:284
T operator()(T y, T x)
Definition binary_ops.h:286
Definition binary_ops.h:249
T operator()(T x, T y)
Definition binary_ops.h:251
Definition binary_ops.h:256
T operator()(T x, T y)
Definition binary_ops.h:258
Definition binary_ops.h:263
T operator()(T x, T y)
Definition binary_ops.h:265
Definition binary_ops.h:291
metal::array< T, 2 > operator()(T x, T y)
Definition binary_ops.h:293
Definition binary_ops.h:34
T operator()(T x, T y)
Definition binary_ops.h:36
Definition binary_ops.h:70
bool operator()(T x, T y)
Definition binary_ops.h:72
Definition binary_ops.h:15
T operator()(T x, T y)
Definition binary_ops.h:17
bfloat16_t operator()(bfloat16_t x, bfloat16_t y)
Definition binary_ops.h:29
half operator()(half x, half y)
Definition binary_ops.h:25
float operator()(float x, float y)
Definition binary_ops.h:21
Definition binary_ops.h:99
bool operator()(T x, T y)
Definition binary_ops.h:101
Definition binary_ops.h:92
bool operator()(T x, T y)
Definition binary_ops.h:94
Definition binary_ops.h:270
T operator()(T x, T y)
Definition binary_ops.h:272
Definition binary_ops.h:113
bool operator()(T x, T y)
Definition binary_ops.h:115
Definition binary_ops.h:106
bool operator()(T x, T y)
Definition binary_ops.h:108
Definition binary_ops.h:120
T operator()(T x, T y)
Definition binary_ops.h:122
Definition binary_ops.h:235
T operator()(T x, T y)
Definition binary_ops.h:237
Definition binary_ops.h:242
T operator()(T x, T y)
Definition binary_ops.h:244
Definition binary_ops.h:135
metal::enable_if_t<!metal::is_integral_v< T >, T > operator()(T x, T y)
Definition binary_ops.h:142
metal::enable_if_t< metal::is_integral_v< T >, T > operator()(T x, T y)
Definition binary_ops.h:137
complex64_t operator()(complex64_t x, complex64_t y)
Definition binary_ops.h:150
Definition binary_ops.h:158
metal::enable_if_t<!metal::is_integral_v< T >, T > operator()(T x, T y)
Definition binary_ops.h:165
complex64_t operator()(complex64_t x, complex64_t y)
Definition binary_ops.h:173
metal::enable_if_t< metal::is_integral_v< T >, T > operator()(T x, T y)
Definition binary_ops.h:160
Definition binary_ops.h:181
T operator()(T x, T y)
Definition binary_ops.h:183
Definition binary_ops.h:77
bool operator()(T x, T y)
Definition binary_ops.h:79
bool operator()(complex64_t x, complex64_t y)
Definition binary_ops.h:83
Definition binary_ops.h:188
bool operator()(complex64_t x, complex64_t y)
Definition binary_ops.h:194
bool operator()(T x, T y)
Definition binary_ops.h:190
Definition binary_ops.h:199
complex64_t operator()(complex64_t x, complex64_t y)
Definition binary_ops.h:219
metal::enable_if_t<!metal::is_integral_v< T >, T > operator()(T base, T exp)
Definition binary_ops.h:201
metal::enable_if_t< metal::is_integral_v< T >, T > operator()(T base, T exp)
Definition binary_ops.h:206
Definition binary_ops.h:41
metal::enable_if_t< metal::is_integral_v< T > &metal::is_signed_v< T >, T > operator()(T x, T y)
Definition binary_ops.h:49
metal::enable_if_t<!metal::is_integral_v< T >, T > operator()(T x, T y)
Definition binary_ops.h:57
metal::enable_if_t< metal::is_integral_v< T > &!metal::is_signed_v< T >, T > operator()(T x, T y)
Definition binary_ops.h:44
complex64_t operator()(complex64_t x, complex64_t y)
Definition binary_ops.h:65
Definition binary_ops.h:277
T operator()(T x, T y)
Definition binary_ops.h:279
Definition binary_ops.h:228
T operator()(T x, T y)
Definition binary_ops.h:230
Definition complex.h:20
float imag
Definition complex.h:22
float real
Definition complex.h:21