MLX
Loading...
Searching...
No Matches
unary_ops.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <metal_integer>
6#include <metal_math>
7
10
// File-local single-precision positive infinity, held in Metal's `constant`
// address space.
namespace {
constant float inf = metal::numeric_limits<float>::infinity();
} // namespace
14
15struct Abs {
16 template <typename T>
17 T operator()(T x) {
18 return metal::abs(x);
19 };
20 template <>
21 uint8_t operator()(uint8_t x) {
22 return x;
23 };
24 template <>
25 uint16_t operator()(uint16_t x) {
26 return x;
27 };
28 template <>
29 uint32_t operator()(uint32_t x) {
30 return x;
31 };
32 template <>
33 uint64_t operator()(uint64_t x) {
34 return x;
35 };
36 template <>
37 bool operator()(bool x) {
38 return x;
39 };
40 template <>
42 return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
43 };
44};
45
// Elementwise inverse cosine, dispatched to the metal::precise overload for T.
struct ArcCos {
  template <typename T>
  T operator()(T v) {
    return metal::precise::acos(v);
  }
};
52
// Elementwise inverse hyperbolic cosine, dispatched to the metal::precise
// overload for T.
struct ArcCosh {
  template <typename T>
  T operator()(T v) {
    return metal::precise::acosh(v);
  }
};
59
// Elementwise inverse sine, dispatched to the metal::precise overload for T.
struct ArcSin {
  template <typename T>
  T operator()(T v) {
    return metal::precise::asin(v);
  }
};
66
// Elementwise inverse hyperbolic sine, dispatched to the metal::precise
// overload for T.
struct ArcSinh {
  template <typename T>
  T operator()(T v) {
    return metal::precise::asinh(v);
  }
};
73
// Elementwise arctangent, dispatched to the metal::precise overload for T.
struct ArcTan {
  template <typename T>
  T operator()(T v) {
    return metal::precise::atan(v);
  }
};
80
// Elementwise inverse hyperbolic tangent, dispatched to the metal::precise
// overload for T.
struct ArcTanh {
  template <typename T>
  T operator()(T v) {
    return metal::precise::atanh(v);
  }
};
87
// Elementwise ceiling.
//
// Integer and bool inputs are already integral, so the explicit
// specializations below return them unchanged. All other types (e.g. float,
// half) are forwarded to metal::ceil.
struct Ceil {
  template <typename T>
  T operator()(T v) {
    return metal::ceil(v);
  }

  // Unsigned integers: identity.
  template <>
  uint8_t operator()(uint8_t v) {
    return v;
  }
  template <>
  uint16_t operator()(uint16_t v) {
    return v;
  }
  template <>
  uint32_t operator()(uint32_t v) {
    return v;
  }
  template <>
  uint64_t operator()(uint64_t v) {
    return v;
  }

  // Signed integers: identity.
  template <>
  int8_t operator()(int8_t v) {
    return v;
  }
  template <>
  int16_t operator()(int16_t v) {
    return v;
  }
  template <>
  int32_t operator()(int32_t v) {
    return v;
  }
  template <>
  int64_t operator()(int64_t v) {
    return v;
  }

  // bool: identity.
  template <>
  bool operator()(bool v) {
    return v;
  }
};
130
131struct Cos {
132 template <typename T>
133 T operator()(T x) {
134 return metal::precise::cos(x);
135 };
136
137 template <>
143};
144
145struct Cosh {
146 template <typename T>
147 T operator()(T x) {
148 return metal::precise::cosh(x);
149 };
150
151 template <>
157};
158
159struct Conjugate {
161 return complex64_t{x.real, -x.imag};
162 }
163};
164
// Elementwise error function.
//
// The input is widened to float, passed to a float `erf` (presumably the
// project's erf.h helper; verify at the include site), and narrowed back to T.
struct Erf {
  template <typename T>
  T operator()(T v) {
    const float as_float = static_cast<float>(v);
    return static_cast<T>(erf(as_float));
  }
};
171
// Elementwise inverse error function.
//
// The input is widened to float, passed to a float `erfinv` helper declared
// elsewhere in the project, and narrowed back to T.
struct ErfInv {
  template <typename T>
  T operator()(T v) {
    const float as_float = static_cast<float>(v);
    return static_cast<T>(erfinv(as_float));
  }
};
178
179struct Exp {
180 template <typename T>
181 T operator()(T x) {
182 return metal::precise::exp(x);
183 };
184 template <>
189};
190
// Elementwise exp(x) - 1.
//
// The input is widened to float, passed to a float `expm1f` helper
// (presumably the project's expm1f.h; verify at the include site), and
// narrowed back to T.
struct Expm1 {
  template <typename T>
  T operator()(T v) {
    const float as_float = static_cast<float>(v);
    return static_cast<T>(expm1f(as_float));
  }
};
197
// Elementwise floor.
//
// Integer and bool inputs are already integral, so the explicit
// specializations below return them unchanged. All other types (e.g. float,
// half) are forwarded to metal::floor.
struct Floor {
  template <typename T>
  T operator()(T v) {
    return metal::floor(v);
  }

  // Unsigned integers: identity.
  template <>
  uint8_t operator()(uint8_t v) {
    return v;
  }
  template <>
  uint16_t operator()(uint16_t v) {
    return v;
  }
  template <>
  uint32_t operator()(uint32_t v) {
    return v;
  }
  template <>
  uint64_t operator()(uint64_t v) {
    return v;
  }

  // Signed integers: identity.
  template <>
  int8_t operator()(int8_t v) {
    return v;
  }
  template <>
  int16_t operator()(int16_t v) {
    return v;
  }
  template <>
  int32_t operator()(int32_t v) {
    return v;
  }
  template <>
  int64_t operator()(int64_t v) {
    return v;
  }

  // bool: identity.
  template <>
  bool operator()(bool v) {
    return v;
  }
};
240
// Elementwise natural logarithm, dispatched to the metal::precise overload
// for T.
struct Log {
  template <typename T>
  T operator()(T v) {
    return metal::precise::log(v);
  }
};
247
// Elementwise base-2 logarithm, dispatched to the metal::precise overload
// for T.
struct Log2 {
  template <typename T>
  T operator()(T v) {
    return metal::precise::log2(v);
  }
};
254
// Elementwise base-10 logarithm, dispatched to the metal::precise overload
// for T.
struct Log10 {
  template <typename T>
  T operator()(T v) {
    return metal::precise::log10(v);
  }
};
261
// Elementwise log(1 + x).
//
// The call is deliberately unqualified: it binds to whatever `log1p` is in
// scope where this header is included. That is presumably the project's float
// helper from utils.h (verify). For other T, the argument and result go
// through implicit conversions.
struct Log1p {
  template <typename T>
  T operator()(T v) {
    return log1p(v);
  }
};
268
270 template <typename T>
271 T operator()(T x) {
272 return !x;
273 };
274};
275
// Elementwise arithmetic negation using T's unary minus.
struct Negative {
  template <typename T>
  T operator()(T v) {
    T negated = -v;
    return negated;
  }
};
282
283struct Round {
284 template <typename T>
285 T operator()(T x) {
286 return metal::rint(x);
287 };
288 template <>
292};
293
// Elementwise logistic sigmoid, 1 / (1 + exp(-x)).
//
// Evaluation scheme:
//   e = exp(-|x|) lies in [0, 1], so exp never overflows;
//   s = 1 / (1 + e) = sigmoid(|x|) lies in [0.5, 1].
//
// For x >= 0 the result is s. For x < 0 the result is sigmoid(x) = e / (1 + e),
// formed here as the product e * s.
//
// Why not `1 - s`: when e drops below the type's epsilon, s rounds to exactly
// 1 and `1 - s` returns 0. For float that happens roughly for x < -17. The
// product keeps full relative precision there. Results at +/-inf (1 and 0) and
// NaN propagation are the same as with the subtraction form.
struct Sigmoid {
  template <typename T>
  T operator()(T x) {
    auto e = metal::exp(-metal::abs(x));
    auto s = 1 / (1 + e);
    return (x < 0) ? e * s : s;
  }
};
301
302struct Sign {
303 template <typename T>
304 T operator()(T x) {
305 return (x > T(0)) - (x < T(0));
306 };
307 template <>
308 uint32_t operator()(uint32_t x) {
309 return x != 0;
310 };
311 template <>
313 if (x == complex64_t(0)) {
314 return x;
315 }
316 return x /
318 };
319};
320
321struct Sin {
322 template <typename T>
323 T operator()(T x) {
324 return metal::precise::sin(x);
325 };
326
327 template <>
333};
334
335struct Sinh {
336 template <typename T>
337 T operator()(T x) {
338 return metal::precise::sinh(x);
339 };
340
341 template <>
347};
348
// Elementwise square, computed as a single multiplication of T by itself.
struct Square {
  template <typename T>
  T operator()(T v) {
    T squared = v * v;
    return squared;
  }
};
355
// Elementwise square root, dispatched to the metal::precise overload for T.
struct Sqrt {
  template <typename T>
  T operator()(T v) {
    return metal::precise::sqrt(v);
  }
};
362
// Elementwise reciprocal square root (1 / sqrt(x)), dispatched to the
// metal::precise overload for T.
struct Rsqrt {
  template <typename T>
  T operator()(T v) {
    return metal::precise::rsqrt(v);
  }
};
369
370struct Tan {
371 template <typename T>
372 T operator()(T x) {
373 return metal::precise::tan(x);
374 };
375
376 template <>
378 float tan_a = metal::precise::tan(x.real);
379 float tanh_b = metal::precise::tanh(x.imag);
380 float t1 = tan_a * tanh_b;
381 float denom = 1. + t1 * t1;
382 return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
383 };
384};
385
386struct Tanh {
387 template <typename T>
388 T operator()(T x) {
389 return metal::precise::tanh(x);
390 };
391
392 template <>
394 float tanh_a = metal::precise::tanh(x.real);
395 float tan_b = metal::precise::tan(x.imag);
396 float t1 = tanh_a * tan_b;
397 float denom = 1. + t1 * t1;
398 return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
399 };
400};
float log1p(float x)
Definition utils.h:414
float erfinv(float a)
Definition erf.h:42
float erf(float a)
Definition erf.h:11
float expm1f(float a)
Definition expm1f.h:80
METAL_FUNC bfloat16_t acosh(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t log10(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t log2(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t sin(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t cosh(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t tanh(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t tan(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t acos(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t atanh(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t asinh(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t atan(bfloat16_t y_over_x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t sinh(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t cos(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t sqrt(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t asin(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t rsqrt(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t floor(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t rint(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t abs(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t ceil(bfloat16_t x)
Definition bf16_math.h:234
Definition unary_ops.h:15
uint8_t operator()(uint8_t x)
Definition unary_ops.h:21
uint32_t operator()(uint32_t x)
Definition unary_ops.h:29
T operator()(T x)
Definition unary_ops.h:17
complex64_t operator()(complex64_t x)
Definition unary_ops.h:41
bool operator()(bool x)
Definition unary_ops.h:37
uint64_t operator()(uint64_t x)
Definition unary_ops.h:33
uint16_t operator()(uint16_t x)
Definition unary_ops.h:25
Definition unary_ops.h:46
T operator()(T x)
Definition unary_ops.h:48
Definition unary_ops.h:53
T operator()(T x)
Definition unary_ops.h:55
Definition unary_ops.h:60
T operator()(T x)
Definition unary_ops.h:62
Definition unary_ops.h:67
T operator()(T x)
Definition unary_ops.h:69
Definition unary_ops.h:74
T operator()(T x)
Definition unary_ops.h:76
Definition unary_ops.h:81
T operator()(T x)
Definition unary_ops.h:83
Definition unary_ops.h:88
int16_t operator()(int16_t x)
Definition unary_ops.h:98
bool operator()(bool x)
Definition unary_ops.h:126
uint32_t operator()(uint32_t x)
Definition unary_ops.h:118
int8_t operator()(int8_t x)
Definition unary_ops.h:94
T operator()(T x)
Definition unary_ops.h:90
int64_t operator()(int64_t x)
Definition unary_ops.h:106
uint64_t operator()(uint64_t x)
Definition unary_ops.h:122
uint8_t operator()(uint8_t x)
Definition unary_ops.h:110
uint16_t operator()(uint16_t x)
Definition unary_ops.h:114
int32_t operator()(int32_t x)
Definition unary_ops.h:102
Definition unary_ops.h:159
complex64_t operator()(complex64_t x)
Definition unary_ops.h:160
Definition unary_ops.h:131
complex64_t operator()(complex64_t x)
Definition unary_ops.h:138
T operator()(T x)
Definition unary_ops.h:133
Definition unary_ops.h:145
T operator()(T x)
Definition unary_ops.h:147
complex64_t operator()(complex64_t x)
Definition unary_ops.h:152
Definition unary_ops.h:165
T operator()(T x)
Definition unary_ops.h:167
Definition unary_ops.h:172
T operator()(T x)
Definition unary_ops.h:174
Definition unary_ops.h:179
complex64_t operator()(complex64_t x)
Definition unary_ops.h:185
T operator()(T x)
Definition unary_ops.h:181
Definition unary_ops.h:191
T operator()(T x)
Definition unary_ops.h:193
Definition unary_ops.h:198
int8_t operator()(int8_t x)
Definition unary_ops.h:204
int16_t operator()(int16_t x)
Definition unary_ops.h:208
int32_t operator()(int32_t x)
Definition unary_ops.h:212
uint16_t operator()(uint16_t x)
Definition unary_ops.h:224
uint64_t operator()(uint64_t x)
Definition unary_ops.h:232
uint32_t operator()(uint32_t x)
Definition unary_ops.h:228
int64_t operator()(int64_t x)
Definition unary_ops.h:216
bool operator()(bool x)
Definition unary_ops.h:236
uint8_t operator()(uint8_t x)
Definition unary_ops.h:220
T operator()(T x)
Definition unary_ops.h:200
Definition unary_ops.h:255
T operator()(T x)
Definition unary_ops.h:257
Definition unary_ops.h:262
T operator()(T x)
Definition unary_ops.h:264
Definition unary_ops.h:248
T operator()(T x)
Definition unary_ops.h:250
Definition unary_ops.h:241
T operator()(T x)
Definition unary_ops.h:243
Definition unary_ops.h:269
T operator()(T x)
Definition unary_ops.h:271
Definition unary_ops.h:276
T operator()(T x)
Definition unary_ops.h:278
Definition unary_ops.h:283
T operator()(T x)
Definition unary_ops.h:285
complex64_t operator()(complex64_t x)
Definition unary_ops.h:289
Definition unary_ops.h:363
T operator()(T x)
Definition unary_ops.h:365
Definition unary_ops.h:294
T operator()(T x)
Definition unary_ops.h:296
Definition unary_ops.h:302
T operator()(T x)
Definition unary_ops.h:304
uint32_t operator()(uint32_t x)
Definition unary_ops.h:308
complex64_t operator()(complex64_t x)
Definition unary_ops.h:312
Definition unary_ops.h:321
T operator()(T x)
Definition unary_ops.h:323
complex64_t operator()(complex64_t x)
Definition unary_ops.h:328
Definition unary_ops.h:335
T operator()(T x)
Definition unary_ops.h:337
complex64_t operator()(complex64_t x)
Definition unary_ops.h:342
Definition unary_ops.h:356
T operator()(T x)
Definition unary_ops.h:358
Definition unary_ops.h:349
T operator()(T x)
Definition unary_ops.h:351
Definition unary_ops.h:370
T operator()(T x)
Definition unary_ops.h:372
complex64_t operator()(complex64_t x)
Definition unary_ops.h:377
Definition unary_ops.h:386
complex64_t operator()(complex64_t x)
Definition unary_ops.h:393
T operator()(T x)
Definition unary_ops.h:388
Definition complex.h:20
float imag
Definition complex.h:22
float real
Definition complex.h:21