MLX
 
Loading...
Searching...
No Matches
binary_ops.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
6
8
9using namespace mlx::core::simd;
10
// BINARY_SINGLE(): emits a scalar overload `T operator()(T x, T y)` inside an
// op struct. Each argument is wrapped in a one-lane Simd<T, 1>, the struct's
// own Simd overload is invoked via (*this), and the single lane is unwrapped
// through `.value`. The scalar and vector paths therefore share one
// implementation. The enclosing struct must provide a Simd<T, N> overload of
// operator() that accepts two Simd<T, 1> arguments.
#define BINARY_SINGLE()                                 \
  template <typename T>                                 \
  T operator()(T x, T y) {                              \
    return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \
  }
16
// DEFAULT_BINARY_OP(Op, op): declares `struct Op` with two call forms.
//  - Vector form: forwards both Simd<T, N> operands to `op(x, y)`. `op` must be
//    callable on two Simd<T, N> values and yield something convertible to
//    Simd<T, N>.
//  - Scalar form: generated by BINARY_SINGLE(), which routes through the
//    vector form using one-lane Simd values.
#define DEFAULT_BINARY_OP(Op, op)                       \
  struct Op {                                           \
    template <int N, typename T>                        \
    Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
      return op(x, y);                                  \
    }                                                   \
    BINARY_SINGLE()                                     \
  };
25
42
// DEFAULT_BOOL_OP(Op, op): like DEFAULT_BINARY_OP, but for predicates.
//  - Vector form: returns the Simd<bool, N> mask produced by `op(x, y)`.
//  - Scalar form: evaluates the vector form on one lane and returns `bool`.
// BINARY_SINGLE() is not reused here because its scalar overload returns T,
// whereas a predicate's scalar result must be bool.
#define DEFAULT_BOOL_OP(Op, op)                            \
  struct Op {                                              \
    template <int N, typename T>                           \
    Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
      return op(x, y);                                     \
    }                                                      \
    template <typename T>                                  \
    bool operator()(T x, T y) {                            \
      return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \
    }                                                      \
  };
54
61
62struct NaNEqual {
63 template <int N, typename T>
65 return x == y || (isnan(x) && isnan(y));
66 }
67 template <typename T>
68 bool operator()(T x, T y) {
69 return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value;
70 }
71};
72
73struct LogAddExp {
74 template <int N, typename T>
76 auto maxval = maximum(x, y);
77 auto minval = minimum(x, y);
78 auto mask = minval == -inf || maxval == inf;
79 auto out = maxval + log1p(exp(minval - maxval));
80 return select(mask, Simd<T, N>(maxval), Simd<T, N>(out));
81 }
83};
84
85struct Select {
86 template <typename T>
87 T operator()(bool condition, T x, T y) {
88 return (*this)(Simd<bool, 1>(condition), Simd<T, 1>(x), Simd<T, 1>(y))
89 .value;
90 }
91
92 template <int N, typename T>
94 return select(condition, x, y);
95 }
96};
97
98} // namespace mlx::core::detail
#define DEFAULT_BINARY_OP(Op, op)
Definition binary_ops.h:17
#define DEFAULT_BOOL_OP(Op, op)
Definition binary_ops.h:43
#define BINARY_SINGLE()
Definition binary_ops.h:11
Definition binary_ops.h:7
Definition accelerate_fp16_simd.h:9
Simd< bool, N > isnan(Simd< T, N > v)
Definition accelerate_simd.h:146
Simd< T, N > minimum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:215
Simd< float16_t, N > pow(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:54
Simd< float16_t, N > atan2(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:52
constexpr float inf
Definition math.h:9
Simd< T, N > maximum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:209
Simd< T, N > exp(Simd< T, N > in)
Compute exp(x) in an optimizer friendly way as follows:
Definition math.h:28
Simd< float16_t, N > remainder(Simd< float16_t, N > x, Simd< float16_t, N > y)
Definition accelerate_fp16_simd.h:53
Simd< float16_t, N > log1p(Simd< float16_t, N > v)
Definition accelerate_fp16_simd.h:40
Simd< T1, N > select(Simd< MaskT, N > mask, Simd< T1, N > x, Simd< T2, N > y)
Definition accelerate_simd.h:236
Definition binary_ops.h:26
Definition binary_ops.h:27
Definition binary_ops.h:33
Definition binary_ops.h:34
Definition binary_ops.h:35
Definition binary_ops.h:28
Definition binary_ops.h:55
Definition binary_ops.h:57
Definition binary_ops.h:56
Definition binary_ops.h:36
Definition binary_ops.h:59
Definition binary_ops.h:58
Definition binary_ops.h:73
Simd< T, N > operator()(Simd< T, N > x, Simd< T, N > y)
Definition binary_ops.h:75
Definition binary_ops.h:31
Definition binary_ops.h:32
Definition binary_ops.h:39
Definition binary_ops.h:40
Definition binary_ops.h:29
Definition binary_ops.h:62
bool operator()(T x, T y)
Definition binary_ops.h:68
Simd< bool, N > operator()(Simd< T, N > x, Simd< T, N > y)
Definition binary_ops.h:64
Definition binary_ops.h:60
Definition binary_ops.h:41
Definition binary_ops.h:38
Definition binary_ops.h:37
Definition binary_ops.h:85
Simd< T, N > operator()(Simd< bool, N > condition, Simd< T, N > x, Simd< T, N > y)
Definition binary_ops.h:93
T operator()(bool condition, T x, T y)
Definition binary_ops.h:87
Definition binary_ops.h:30
Definition accelerate_simd.h:55