MLX
 
Loading...
Searching...
No Matches
unary_ops.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
5#include <stdint.h>
6#include <cmath>
7#include <complex>
8
10
11namespace mlx::core::detail {
12
13using namespace mlx::core::simd;
14
// SINGLE(): injects a scalar overload of operator() into a functor that
// already defines a templated `operator()(Simd<T, N>)`. The scalar argument
// is wrapped in a one-lane Simd<T, 1>, dispatched to the Simd overload, and
// that lane's `.value` is handed back as a plain T.
// (No comments inside the macro body: a `//` comment before a trailing `\`
// would swallow the next spliced line.)
#define SINGLE()                                \
  template <typename T>                         \
  auto operator()(T x) -> T {                   \
    return (*this)(Simd<T, 1>(x)).value;        \
  }
20
// DEFAULT_OP(Op, op): declares a functor struct named `Op`.
//   - Its Simd overload passes a Simd<T, N> straight to `simd::op` and
//     returns the result as Simd<T, N>.
//   - Its scalar overload is supplied by SINGLE().
// The expansion already ends in `};`, so uses are written without a
// trailing semicolon.
#define DEFAULT_OP(Op, op)                              \
  struct Op {                                           \
    template <int N, typename T>                        \
    auto operator()(Simd<T, N> x) -> Simd<T, N> {       \
      return simd::op(x);                               \
    }                                                   \
    SINGLE()                                            \
  };
29
60
// Imag: per the doxygen index on this page, its member is
// `Simd<float, N> operator()(Simd<complex64_t, N> x)` (unary_ops.h:63),
// i.e. it maps complex64 lanes to float lanes — presumably the imaginary
// part, via simd::imag (listed in base_simd.h); confirm.
// NOTE(review): upstream lines 63-66 (the operator() definition and what
// follows it) are missing from this extraction; restore them from the
// original header before compiling.
61struct Imag {
62 template <int N>
67};
68
// Real: per the doxygen index on this page, its member is
// `Simd<float, N> operator()(Simd<complex64_t, N> x)` (unary_ops.h:71),
// i.e. it maps complex64 lanes to float lanes — presumably the real part,
// via simd::real (listed in base_simd.h); confirm.
// NOTE(review): upstream lines 71-74 (the operator() definition and what
// follows it) are missing from this extraction; restore them from the
// original header before compiling.
69struct Real {
70 template <int N>
75};
76
// Sigmoid: elementwise logistic function, computed as 1 / (1 + exp(-x))
// with float-literal constants combined with the Simd<T, N> operand.
// NOTE(review): upstream lines 79 and 82 are missing from this extraction.
// Per the doxygen index on this page, line 79 is the signature
// `Simd<T, N> operator()(Simd<T, N> x)`. Line 82 is not shown; it is
// presumably SINGLE() (the scalar overload), matching the DEFAULT_OP
// pattern. Restore both from the original header before compiling.
77struct Sigmoid {
78 template <int N, typename T>
80 return 1.0f / (1.0f + simd::exp(-x));
81 }
83};
84
// Sign: elementwise sign of x; the branch is chosen at compile time from T.
//   - unsigned T: returns `x != 0`, i.e. nonzero -> 1, zero -> 0. This relies
//     on the comparison result (presumably a Simd<bool, N> mask — confirm)
//     converting to Simd<T, N>.
//   - complex64_t: returns x / simd::abs(x) for nonzero x; a zero input is
//     selected through unchanged (so it stays 0).
//   - everything else: -1 where x < 0, +1 where x > 0, otherwise 0.
// NOTE(review): for floating-point NaN both `x < z` and `x > z` are false
// (assuming Simd comparisons follow IEEE semantics), so NaN maps to 0 rather
// than NaN — confirm this is intended.
// NOTE(review): upstream lines 87 and 98 are missing from this extraction.
// Per the doxygen index on this page, line 87 is the signature
// `Simd<T, N> operator()(Simd<T, N> x)`. Line 98 is not shown; it is
// presumably SINGLE(), matching the DEFAULT_OP pattern. Restore both from the
// original header before compiling.
85struct Sign {
86 template <int N, typename T>
88 auto z = Simd<T, N>{0};
89 if constexpr (std::is_unsigned_v<T>) {
90 return x != z;
91 } else if constexpr (std::is_same_v<T, complex64_t>) {
92 return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
93 } else {
94 return simd::select(
95 x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
96 }
97 }
99};
100
// Square: elementwise x * x (a plain multiply rather than a pow call).
// NOTE(review): upstream lines 103 and 106 are missing from this extraction.
// Per the doxygen index on this page, line 103 is the signature
// `Simd<T, N> operator()(Simd<T, N> x)`. Line 106 is not shown; it is
// presumably SINGLE(), matching the DEFAULT_OP pattern. Restore both from
// the original header before compiling.
101struct Square {
102 template <int N, typename T>
104 return x * x;
105 }
107};
108
109} // namespace mlx::core::detail
#define SINGLE()
Definition unary_ops.h:15
#define DEFAULT_OP(Op, op)
Definition unary_ops.h:21
Definition binary_ops.h:7
Simd< float16_t, N > sinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:41
Simd< float16_t, N > atanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:34
Simd< float16_t, N > log10(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:39
Simd< T, N > rint(Simd< T, N > v)
Definition accelerate_simd.h:127
Simd< float16_t, N > tan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:42
Simd< T, N > abs(Simd< T, N > v)
Definition accelerate_simd.h:112
Simd< float16_t, N > acosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:30
Simd< T, N > erf(Simd< T, N > x)
Definition math.h:137
Simd< T, 1 > conj(Simd< T, 1 > in)
Definition base_simd.h:85
Simd< float16_t, N > log2(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:38
Simd< T, N > erfinv(Simd< T, N > a_)
Definition math.h:151
Simd< T, N > exp(Simd< T, N > in)
Compute exp(x) in an optimizer-friendly way as follows:
Definition math.h:28
Simd< float16_t, N > log(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:37
Simd< T, N > floor(Simd< T, N > v)
Definition accelerate_simd.h:113
Simd< float16_t, N > expm1(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:36
auto imag(Simd< T, 1 > in) -> Simd< decltype(std::imag(in.value)), 1 >
Definition base_simd.h:108
Simd< float16_t, N > asin(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:31
Simd< float16_t, N > tanh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:43
Simd< float16_t, N > atan(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:33
Simd< float16_t, N > asinh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:32
Simd< T, N > cos(Simd< T, N > x)
Definition math.h:128
Simd< T, N > sin(Simd< T, N > x)
Definition math.h:119
auto real(Simd< T, 1 > in) -> Simd< decltype(std::real(in.value)), 1 >
Definition base_simd.h:104
Simd< float16_t, N > log1p(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:40
Simd< T, N > ceil(Simd< T, N > v)
Definition accelerate_simd.h:120
Simd< T, N > sqrt(Simd< T, N > v)
Definition accelerate_simd.h:129
Simd< float16_t, N > acos(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:29
Simd< T, N > rsqrt(Simd< T, N > v)
Definition accelerate_simd.h:130
Simd< float16_t, N > cosh(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:35
Simd< T1, N > select(Simd< MaskT, N > mask, Simd< T1, N > x, Simd< T2, N > y)
Definition accelerate_simd.h:236
Definition unary_ops.h:30
Definition unary_ops.h:31
Definition unary_ops.h:32
Definition unary_ops.h:33
Definition unary_ops.h:34
Definition unary_ops.h:35
Definition unary_ops.h:36
Definition unary_ops.h:37
Definition unary_ops.h:38
Definition unary_ops.h:39
Definition unary_ops.h:40
Definition unary_ops.h:41
Definition unary_ops.h:42
Definition unary_ops.h:43
Definition unary_ops.h:44
Definition unary_ops.h:45
Definition unary_ops.h:46
Definition unary_ops.h:61
Simd< float, N > operator()(Simd< complex64_t, N > x)
Definition unary_ops.h:63
Definition unary_ops.h:49
Definition unary_ops.h:50
Definition unary_ops.h:48
Definition unary_ops.h:47
Definition unary_ops.h:51
Definition unary_ops.h:52
Definition unary_ops.h:69
Simd< float, N > operator()(Simd< complex64_t, N > x)
Definition unary_ops.h:71
Definition unary_ops.h:53
Definition unary_ops.h:57
Definition unary_ops.h:77
Simd< T, N > operator()(Simd< T, N > x)
Definition unary_ops.h:79
Definition unary_ops.h:85
Simd< T, N > operator()(Simd< T, N > x)
Definition unary_ops.h:87
Definition unary_ops.h:54
Definition unary_ops.h:55
Definition unary_ops.h:56
Definition unary_ops.h:101
Simd< T, N > operator()(Simd< T, N > x)
Definition unary_ops.h:103
Definition unary_ops.h:58
Definition unary_ops.h:59
Definition accelerate_simd.h:55