MLX
 
Loading...
Searching...
No Matches
neon_fp16_simd.h
Go to the documentation of this file.
1#pragma once
2
3#include <arm_neon.h>
4
6
7namespace mlx::core::simd {
8
// Lane count used by the float16 Simd specialization in this header; it
// matches the eight lanes of the float16x8_t register that type wraps.
9constexpr int N = 8;
10
11template <>
12struct Simd<float16_t, N> {
13 static constexpr int size = N;
15
17
18 template <typename U>
19 Simd<float16_t, N>(U v) : value(vdupq_n_f16(v)){};
20
21 Simd<float16_t, N>(float16x8_t v) : value(v){};
22
23 Simd<float16_t, N>(Simd<float, N> other) {
24 auto f32x4_a = *(float32x4_t*)(&other);
25 auto f32x4_b = *((float32x4_t*)(&other) + 1);
26 value = vcvt_high_f16_f32(vcvt_f16_f32(f32x4_a), f32x4_b);
27 };
28
29 Simd<float16_t, N>(Simd<uint16_t, N> other) {
30 value = vcvtq_f16_u16(*(uint16x8_t*)(&other.value));
31 };
32
33 operator Simd<int16_t, N>() {
34 auto v = vcvtq_s16_f16(value);
35 return load<int16_t, N>((int16_t*)&v);
36 };
37
38 operator Simd<float, N>() {
39 float32x4x2_t v;
40 v.val[0] = vcvt_f32_f16(*(float16x4_t*)(&value));
41 v.val[1] = vcvt_high_f32_f16(value);
42 return load<float, N>((float*)&v);
43 }
44 float16_t operator[](int idx) const {
45 return reinterpret_cast<const float16_t*>(&value)[idx];
46 }
47
49 return reinterpret_cast<float16_t*>(&value)[idx];
50 }
51
52 float16x8_t value;
53};
54
// Defines `name(Simd<float16_t, N>)` as a wrapper that applies the NEON
// intrinsic `op` to the wrapped register and re-wraps the result.
#define DEFINE_NEON_UNARY_OP(name, op)                   \
  inline Simd<float16_t, N> name(Simd<float16_t, N> a) { \
    auto result = op(a.value);                           \
    return Simd<float16_t, N>{result};                   \
  }
59
67
// Defines three overloads of the binary function `name` built on the NEON
// intrinsic `op`: vector/vector, vector/T and T/vector. A non-vector operand
// is first converted to Simd<float16_t, N>.
#define DEFINE_NEON_BINARY_OP(name, op)                                        \
  inline Simd<float16_t, N> name(Simd<float16_t, N> a, Simd<float16_t, N> b) { \
    return Simd<float16_t, N>(op(a.value, b.value));                           \
  }                                                                            \
  template <typename T>                                                        \
  Simd<float16_t, N> name(Simd<float16_t, N> a, T b) {                         \
    const Simd<float16_t, N> rhs(b);                                           \
    return Simd<float16_t, N>(op(a.value, rhs.value));                         \
  }                                                                            \
  template <typename T>                                                        \
  Simd<float16_t, N> name(T a, Simd<float16_t, N> b) {                         \
    const Simd<float16_t, N> lhs(a);                                           \
    return Simd<float16_t, N>(op(lhs.value, b.value));                         \
  }
80
82 auto out = vceqzq_f16(v.value);
83 return Simd<uint16_t, N>(*(uint16_t*)&out);
84}
85
87 return vnegq_f16(v.value);
88}
89
// Instantiate the element-wise arithmetic operators for Simd<float16_t, N>.
// Each line expands to the vector/vector, vector/T and T/vector overloads
// defined by DEFINE_NEON_BINARY_OP.
92DEFINE_NEON_BINARY_OP(operator+, vaddq_f16)
93DEFINE_NEON_BINARY_OP(operator-, vsubq_f16)
94DEFINE_NEON_BINARY_OP(operator*, vmulq_f16)
95DEFINE_NEON_BINARY_OP(operator/, vdivq_f16)
96
// Defines the comparison `operator Op` for Simd<float16_t, N> using the NEON
// intrinsic `op`, in vector/T, T/vector and vector/vector forms. A non-vector
// operand is first converted to Simd<float16_t, N>.
//
// `op` yields a uint16x8_t mask with one 16-bit element per lane. The whole
// mask is loaded into a Simd<uint16_t, N> (rather than dereferencing a single
// uint16_t, which would only see lane 0) before converting to Simd<bool, N>.
#define DEFINE_NEON_COMPARISON(Op, op)                                 \
  template <typename T>                                                \
  Simd<bool, N> operator Op(Simd<float16_t, N> a, T b) {               \
    auto out = op(a.value, Simd<float16_t, N>(b).value);               \
    return load<uint16_t, N>(reinterpret_cast<const uint16_t*>(&out)); \
  }                                                                    \
  template <typename T>                                                \
  Simd<bool, N> operator Op(T a, Simd<float16_t, N> b) {               \
    auto out = op(Simd<float16_t, N>(a).value, b.value);               \
    return load<uint16_t, N>(reinterpret_cast<const uint16_t*>(&out)); \
  }                                                                    \
  inline Simd<bool, N> operator Op(                                    \
      Simd<float16_t, N> a, Simd<float16_t, N> b) {                    \
    auto out = op(a.value, b.value);                                   \
    return load<uint16_t, N>(reinterpret_cast<const uint16_t*>(&out)); \
  }
113
// Instantiate ordering comparisons for Simd<float16_t, N>; each expands to
// the three overloads defined by DEFINE_NEON_COMPARISON and returns
// Simd<bool, N>.
116DEFINE_NEON_COMPARISON(<=, vcleq_f16)
118DEFINE_NEON_COMPARISON(<, vcltq_f16)
119
120template <typename T>
121Simd<bool, N> operator!=(Simd<float16_t, N> a, T b) {
122 return !(a == b);
123}
124template <typename T>
126 return !(a == b);
127}
129 return !(a == b);
130}
131
135 return Simd<uint16_t, N>((a != 0) || (b != 0));
136}
137template <typename T>
139 return Simd<uint16_t, N>((a != 0) || (b != 0));
140}
141template <typename T>
143 return Simd<uint16_t, N>((a != 0) || (b != 0));
144}
148 return Simd<uint16_t, N>((a != 0) && (b != 0));
149}
150template <typename T>
152 return Simd<uint16_t, N>((a != 0) && (b != 0));
153}
154template <typename T>
156 return Simd<uint16_t, N>((a != 0) && (b != 0));
157}
158
159template <>
161 return v != v;
162}
163
164template <>
165inline Simd<float16_t, N>
169
170template <typename T>
174
175template <typename MaskT>
176Simd<float16_t, N>
178 return vbslq_f16(Simd<uint16_t, N>(mask).value, x.value, y.value);
179}
180
181// Reductions
183 float16x4_t y;
184 y = vpmax_f16(vget_low_f16(x.value), vget_high_f16(x.value));
185 y = vpmax_f16(y, y);
186 y = vpmax_f16(y, y);
187 return vget_lane_f16(y, 0);
188}
190 float16x4_t y;
191 y = vpmin_f16(vget_low_f16(x.value), vget_high_f16(x.value));
192 y = vpmin_f16(y, y);
193 y = vpmin_f16(y, y);
194 return vget_lane_f16(y, 0);
195}
197 float16x4_t y;
198 y = vpadd_f16(vget_low_f16(x.value), vget_high_f16(x.value));
199 y = vpadd_f16(y, y);
200 y = vpadd_f16(y, y);
201 return vget_lane_f16(y, 0);
202}
204 auto hx = vmul_f16(vget_low_f16(x.value), vget_high_f16(x.value));
205 auto out = hx[0];
206 hx[0] *= hx[1];
207 hx[0] *= hx[2];
208 hx[0] *= hx[3];
209 return hx[0];
210}
211
212} // namespace mlx::core::simd
Definition accelerate_fp16_simd.h:9
Simd< bool, N > isnan(Simd< T, N > v)
Definition accelerate_simd.h:146
constexpr int N
Definition neon_fp16_simd.h:9
Simd< T, N > minimum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:215
T prod(Simd< T, N > x)
Definition accelerate_simd.h:297
Simd< T, N > rint(Simd< T, N > v)
Definition accelerate_simd.h:127
Simd< T, N > load(const T *x)
Definition base_simd.h:28
Simd< bool, N > operator!=(Simd< T, N > a, U b)
Definition accelerate_simd.h:201
Simd< T, N > abs(Simd< T, N > v)
Definition accelerate_simd.h:112
T sum(Simd< T, N > x)
Definition accelerate_simd.h:284
T max(Simd< T, N > x)
Definition accelerate_simd.h:288
Simd< bool, N > operator!(Simd< T, N > v)
Definition accelerate_simd.h:152
Simd< T, N > maximum(Simd< T, N > a, Simd< T, N > b)
Definition accelerate_simd.h:209
Simd< T, N > operator&&(Simd< T, N > x, U y)
Definition accelerate_simd.h:179
Simd< T, N > floor(Simd< T, N > v)
Definition accelerate_simd.h:113
Simd< T, N > fma(Simd< T, N > x, Simd< T, N > y, U z)
Definition accelerate_simd.h:269
Simd< T, N > operator||(Simd< T, N > x, U y)
Definition accelerate_simd.h:180
T min(Simd< T, N > x)
Definition accelerate_simd.h:292
Simd< T, N > ceil(Simd< T, N > v)
Definition accelerate_simd.h:120
Simd< T, N > recip(Simd< T, N > v)
Definition accelerate_simd.h:131
Simd< T, N > sqrt(Simd< T, N > v)
Definition accelerate_simd.h:129
Simd< T, N > clamp(Simd< T, N > v, Simd< T, N > min, Simd< T, N > max)
Definition accelerate_simd.h:264
Simd< T, N > rsqrt(Simd< T, N > v)
Definition accelerate_simd.h:130
Simd< T, N > operator-(Simd< T, N > v)
Definition accelerate_simd.h:136
Simd< T1, N > select(Simd< MaskT, N > mask, Simd< T1, N > x, Simd< T2, N > y)
Definition accelerate_simd.h:236
struct _MLX_Float16 float16_t
Definition half_types.h:17
#define DEFINE_NEON_BINARY_OP(name, op)
Definition neon_fp16_simd.h:68
#define DEFINE_NEON_COMPARISON(Op, op)
Definition neon_fp16_simd.h:97
#define DEFINE_NEON_UNARY_OP(name, op)
Definition neon_fp16_simd.h:55
Simd()
Definition neon_fp16_simd.h:16
static constexpr int size
Definition neon_fp16_simd.h:13
float16_t scalar_t
Definition neon_fp16_simd.h:14
float16_t operator[](int idx) const
Definition neon_fp16_simd.h:44
float16_t & operator[](int idx)
Definition neon_fp16_simd.h:48
float16x8_t value
Definition neon_fp16_simd.h:52
Definition accelerate_simd.h:55
asd::Vector< scalar_t, N >::packed_t value
Definition accelerate_simd.h:80
Simd()
Definition accelerate_simd.h:59