MLX
Loading...
Searching...
No Matches
unary.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <metal_integer>
6#include <metal_math>
7
12
// File-local constant holding positive infinity as a 32-bit float, obtained
// from metal::numeric_limits<float>. None of the functors visible in this
// listing reference it.
13namespace {
14constant float inf = metal::numeric_limits<float>::infinity();
15}
16
17struct Abs {
18 template <typename T>
19 T operator()(T x) {
20 return metal::abs(x);
21 };
22 template <>
23 uint8_t operator()(uint8_t x) {
24 return x;
25 };
26 template <>
27 uint16_t operator()(uint16_t x) {
28 return x;
29 };
30 template <>
31 uint32_t operator()(uint32_t x) {
32 return x;
33 };
34 template <>
35 uint64_t operator()(uint64_t x) {
36 return x;
37 };
38 template <>
39 bool operator()(bool x) {
40 return x;
41 };
42 template <>
44 return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
45 };
46};
47
48struct ArcCos {
49 template <typename T>
50 T operator()(T x) {
51 return metal::precise::acos(x);
52 };
53};
54
55struct ArcCosh {
56 template <typename T>
57 T operator()(T x) {
58 return metal::precise::acosh(x);
59 };
60};
61
62struct ArcSin {
63 template <typename T>
64 T operator()(T x) {
65 return metal::precise::asin(x);
66 };
67};
68
69struct ArcSinh {
70 template <typename T>
71 T operator()(T x) {
72 return metal::precise::asinh(x);
73 };
74};
75
76struct ArcTan {
77 template <typename T>
78 T operator()(T x) {
79 return metal::precise::atan(x);
80 };
81};
82
83struct ArcTanh {
84 template <typename T>
85 T operator()(T x) {
86 return metal::precise::atanh(x);
87 };
88};
89
90struct Ceil {
91 template <typename T>
92 T operator()(T x) {
93 return metal::ceil(x);
94 };
95 template <>
96 int8_t operator()(int8_t x) {
97 return x;
98 };
99 template <>
100 int16_t operator()(int16_t x) {
101 return x;
102 };
103 template <>
104 int32_t operator()(int32_t x) {
105 return x;
106 };
107 template <>
108 int64_t operator()(int64_t x) {
109 return x;
110 };
111 template <>
112 uint8_t operator()(uint8_t x) {
113 return x;
114 };
115 template <>
116 uint16_t operator()(uint16_t x) {
117 return x;
118 };
119 template <>
120 uint32_t operator()(uint32_t x) {
121 return x;
122 };
123 template <>
124 uint64_t operator()(uint64_t x) {
125 return x;
126 };
127 template <>
128 bool operator()(bool x) {
129 return x;
130 };
131};
132
// Functor for element-wise cosine. The generic overload returns
// metal::precise::cos(x).
133struct Cos {
134 template <typename T>
135 T operator()(T x) {
136 return metal::precise::cos(x);
137 };
138
// NOTE(review): the explicit specialization that begins here has no signature
// or body in this listing (source lines 140-144 are absent); restore it from
// the original header before compiling.
139 template <>
145};
146
// Functor for element-wise hyperbolic cosine. The generic overload returns
// metal::precise::cosh(x).
147struct Cosh {
148 template <typename T>
149 T operator()(T x) {
150 return metal::precise::cosh(x);
151 };
152
// NOTE(review): the explicit specialization that begins here has no signature
// or body in this listing (source lines 154-158 are absent); restore it from
// the original header before compiling.
153 template <>
159};
160
161struct Conjugate {
163 return complex64_t{x.real, -x.imag};
164 }
165};
166
// Functor evaluating erf on the input. The argument is converted to float
// for the call and the result is cast back to T.
struct Erf {
  template <typename T>
  T operator()(T x) {
    const float as_float = static_cast<float>(x);
    return static_cast<T>(erf(as_float));
  }
};
173
174struct ErfInv {
175 template <typename T>
176 T operator()(T x) {
177 return static_cast<T>(erfinv(static_cast<float>(x)));
178 };
179};
180
// Functor for the element-wise exponential. The generic overload returns
// metal::precise::exp(x).
181struct Exp {
182 template <typename T>
183 T operator()(T x) {
184 return metal::precise::exp(x);
185 };
// NOTE(review): the explicit specialization that begins here has no signature
// or body in this listing (source lines 187-190 are absent); restore it from
// the original header before compiling.
186 template <>
191};
192
// Functor evaluating expm1f on the input. The argument is converted to float
// for the call and the result is cast back to T.
struct Expm1 {
  template <typename T>
  T operator()(T x) {
    const float as_float = static_cast<float>(x);
    return static_cast<T>(expm1f(as_float));
  }
};
199
200struct Floor {
201 template <typename T>
202 T operator()(T x) {
203 return metal::floor(x);
204 };
205 template <>
206 int8_t operator()(int8_t x) {
207 return x;
208 };
209 template <>
210 int16_t operator()(int16_t x) {
211 return x;
212 };
213 template <>
214 int32_t operator()(int32_t x) {
215 return x;
216 };
217 template <>
218 int64_t operator()(int64_t x) {
219 return x;
220 };
221 template <>
222 uint8_t operator()(uint8_t x) {
223 return x;
224 };
225 template <>
226 uint16_t operator()(uint16_t x) {
227 return x;
228 };
229 template <>
230 uint32_t operator()(uint32_t x) {
231 return x;
232 };
233 template <>
234 uint64_t operator()(uint64_t x) {
235 return x;
236 };
237 template <>
238 bool operator()(bool x) {
239 return x;
240 };
241};
242
243struct Log {
244 template <typename T>
245 T operator()(T x) {
246 return metal::precise::log(x);
247 };
248};
249
250struct Log2 {
251 template <typename T>
252 T operator()(T x) {
253 return metal::precise::log2(x);
254 };
255};
256
257struct Log10 {
258 template <typename T>
259 T operator()(T x) {
260 return metal::precise::log10(x);
261 };
262};
263
// Functor returning log1p of its argument; the call is unqualified, so it
// resolves to whichever log1p overload is in scope for T.
struct Log1p {
  template <typename T>
  T operator()(T x) {
    const auto result = log1p(x);
    return result;
  }
};
270
// NOTE(review): the struct header for this functor (source line 271) is
// missing from this listing; restore it from the original header.
// The call operator returns the logical negation (!x) converted to T.
272 template <typename T>
273 T operator()(T x) {
274 return !x;
275 };
276};
277
// Functor returning the arithmetic negation of its argument, converted to T.
struct Negative {
  template <typename T>
  T operator()(T x) {
    const auto negated = -x;
    return negated;
  }
};
284
// Functor for element-wise rounding. The generic overload returns
// metal::rint(x).
285struct Round {
286 template <typename T>
287 T operator()(T x) {
288 return metal::rint(x);
289 };
// NOTE(review): the explicit specialization that begins here has no signature
// or body in this listing (source lines 291-293 are absent); restore it from
// the original header before compiling.
290 template <>
294};
295
296struct Sigmoid {
297 template <typename T>
298 T operator()(T x) {
299 auto y = 1 / (1 + metal::exp(-metal::abs(x)));
300 return (x < 0) ? 1 - y : y;
301 }
302};
303
304struct Sign {
305 template <typename T>
306 T operator()(T x) {
307 return (x > T(0)) - (x < T(0));
308 };
309 template <>
310 uint32_t operator()(uint32_t x) {
311 return x != 0;
312 };
313};
314
// Functor for element-wise sine. The generic overload returns
// metal::precise::sin(x).
315struct Sin {
316 template <typename T>
317 T operator()(T x) {
318 return metal::precise::sin(x);
319 };
320
// NOTE(review): the explicit specialization that begins here has no signature
// or body in this listing (source lines 322-326 are absent); restore it from
// the original header before compiling.
321 template <>
327};
328
// Functor for element-wise hyperbolic sine. The generic overload returns
// metal::precise::sinh(x).
329struct Sinh {
330 template <typename T>
331 T operator()(T x) {
332 return metal::precise::sinh(x);
333 };
334
// NOTE(review): the explicit specialization that begins here has no signature
// or body in this listing (source lines 336-340 are absent); restore it from
// the original header before compiling.
335 template <>
341};
342
// Functor returning its argument multiplied by itself, converted to T.
struct Square {
  template <typename T>
  T operator()(T x) {
    const auto squared = x * x;
    return squared;
  }
};
349
350struct Sqrt {
351 template <typename T>
352 T operator()(T x) {
353 return metal::precise::sqrt(x);
354 };
355};
356
357struct Rsqrt {
358 template <typename T>
359 T operator()(T x) {
360 return metal::precise::rsqrt(x);
361 };
362};
363
364struct Tan {
365 template <typename T>
366 T operator()(T x) {
367 return metal::precise::tan(x);
368 };
369
370 template <>
372 float tan_a = metal::precise::tan(x.real);
373 float tanh_b = metal::precise::tanh(x.imag);
374 float t1 = tan_a * tanh_b;
375 float denom = 1. + t1 * t1;
376 return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
377 };
378};
379
380struct Tanh {
381 template <typename T>
382 T operator()(T x) {
383 return metal::precise::tanh(x);
384 };
385
386 template <>
388 float tanh_a = metal::precise::tanh(x.real);
389 float tan_b = metal::precise::tan(x.imag);
390 float t1 = tanh_a * tan_b;
391 float denom = 1. + t1 * t1;
392 return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
393 };
394};
float log1p(float x)
Definition utils.h:298
float erfinv(float a)
Definition erf.h:43
float erf(float a)
Definition erf.h:12
float expm1f(float a)
Definition expm1f.h:80
METAL_FUNC bfloat16_t acosh(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t log10(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t log2(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t sin(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t cosh(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t tanh(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t tan(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t acos(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t atanh(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t asinh(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t atan(bfloat16_t y_over_x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t sinh(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t cos(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t sqrt(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t asin(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t rsqrt(bfloat16_t x)
Definition bf16_math.h:252
METAL_FUNC bfloat16_t floor(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t rint(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t abs(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t ceil(bfloat16_t x)
Definition bf16_math.h:234
Definition unary.h:17
uint8_t operator()(uint8_t x)
Definition unary.h:23
uint32_t operator()(uint32_t x)
Definition unary.h:31
T operator()(T x)
Definition unary.h:19
complex64_t operator()(complex64_t x)
Definition unary.h:43
bool operator()(bool x)
Definition unary.h:39
uint64_t operator()(uint64_t x)
Definition unary.h:35
uint16_t operator()(uint16_t x)
Definition unary.h:27
Definition unary.h:48
T operator()(T x)
Definition unary.h:50
Definition unary.h:55
T operator()(T x)
Definition unary.h:57
Definition unary.h:62
T operator()(T x)
Definition unary.h:64
Definition unary.h:69
T operator()(T x)
Definition unary.h:71
Definition unary.h:76
T operator()(T x)
Definition unary.h:78
Definition unary.h:83
T operator()(T x)
Definition unary.h:85
Definition unary.h:90
int16_t operator()(int16_t x)
Definition unary.h:100
bool operator()(bool x)
Definition unary.h:128
uint32_t operator()(uint32_t x)
Definition unary.h:120
int8_t operator()(int8_t x)
Definition unary.h:96
T operator()(T x)
Definition unary.h:92
int64_t operator()(int64_t x)
Definition unary.h:108
uint64_t operator()(uint64_t x)
Definition unary.h:124
uint8_t operator()(uint8_t x)
Definition unary.h:112
uint16_t operator()(uint16_t x)
Definition unary.h:116
int32_t operator()(int32_t x)
Definition unary.h:104
Definition unary.h:161
complex64_t operator()(complex64_t x)
Definition unary.h:162
Definition unary.h:133
complex64_t operator()(complex64_t x)
Definition unary.h:140
T operator()(T x)
Definition unary.h:135
Definition unary.h:147
T operator()(T x)
Definition unary.h:149
complex64_t operator()(complex64_t x)
Definition unary.h:154
Definition unary.h:167
T operator()(T x)
Definition unary.h:169
Definition unary.h:174
T operator()(T x)
Definition unary.h:176
Definition unary.h:181
complex64_t operator()(complex64_t x)
Definition unary.h:187
T operator()(T x)
Definition unary.h:183
Definition unary.h:193
T operator()(T x)
Definition unary.h:195
Definition unary.h:200
int8_t operator()(int8_t x)
Definition unary.h:206
int16_t operator()(int16_t x)
Definition unary.h:210
int32_t operator()(int32_t x)
Definition unary.h:214
uint16_t operator()(uint16_t x)
Definition unary.h:226
uint64_t operator()(uint64_t x)
Definition unary.h:234
uint32_t operator()(uint32_t x)
Definition unary.h:230
int64_t operator()(int64_t x)
Definition unary.h:218
bool operator()(bool x)
Definition unary.h:238
uint8_t operator()(uint8_t x)
Definition unary.h:222
T operator()(T x)
Definition unary.h:202
Definition unary.h:257
T operator()(T x)
Definition unary.h:259
Definition unary.h:264
T operator()(T x)
Definition unary.h:266
Definition unary.h:250
T operator()(T x)
Definition unary.h:252
Definition unary.h:243
T operator()(T x)
Definition unary.h:245
Definition unary.h:271
T operator()(T x)
Definition unary.h:273
Definition unary.h:278
T operator()(T x)
Definition unary.h:280
Definition unary.h:285
T operator()(T x)
Definition unary.h:287
complex64_t operator()(complex64_t x)
Definition unary.h:291
Definition unary.h:357
T operator()(T x)
Definition unary.h:359
Definition unary.h:296
T operator()(T x)
Definition unary.h:298
Definition unary.h:304
T operator()(T x)
Definition unary.h:306
uint32_t operator()(uint32_t x)
Definition unary.h:310
Definition unary.h:315
T operator()(T x)
Definition unary.h:317
complex64_t operator()(complex64_t x)
Definition unary.h:322
Definition unary.h:329
T operator()(T x)
Definition unary.h:331
complex64_t operator()(complex64_t x)
Definition unary.h:336
Definition unary.h:350
T operator()(T x)
Definition unary.h:352
Definition unary.h:343
T operator()(T x)
Definition unary.h:345
Definition unary.h:364
T operator()(T x)
Definition unary.h:366
complex64_t operator()(complex64_t x)
Definition unary.h:371
Definition unary.h:380
complex64_t operator()(complex64_t x)
Definition unary.h:387
T operator()(T x)
Definition unary.h:382
Definition complex.h:20
float imag
Definition complex.h:22
float real
Definition complex.h:21