MLX
Loading...
Searching...
No Matches
bf16_math.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
6
8// Metal math for bfloat16
10
11/*
12
13Following the Metal Shading Language Specification (Metal 3.1)
14
15"bfloat is an extended floating point type that only allows implicit conversion
16 to a type of greater floating point rank. While bfloat can be implicitly
17 converted to float, it cannot be implicitly converted to half, and neither
18 float nor half can be implicitly converted to bfloat."
19
20Further, as far as I can tell, the stdlib math/simd functions are not defined
21for bfloat, and calling them with an argument of type bfloat will result in that
22argument getting implicitly converted to float, which then returns an output
23that is (likely) a float, which cannot be implicitly converted into a bfloat.
24
25This leads to situations where
26bfloat a = 5.0bf;
27bfloat b = metal::abs(a); // this will throw an error since abs returns float
28bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
29
30For the moment, I will be adding overloaded instantiations of the math
31functions to accordingly automatically handle the casting
32
33*/
34
// Generates one family of scalar math overloads taking `itype` arguments.
//
//   itype - parameter type accepted by every generated overload
//   otype - type every generated overload returns
//   ctype - type the operands are converted to before the builtin is called
//   mfast - trailing mode argument forwarded to the builtins that accept one
//
// Each overload converts its operands to ctype with static_cast, forwards them
// to the matching __metal_* builtin and casts the result to otype.
// Notes:
//   * abs/fabs both forward to __metal_fabs; max/max3/median3/min/min3 forward
//     to the same f-prefixed builtins as fmax/fmax3/fmedian3/fmin/fmin3.
//   * fma, frexp and nextafter call their builtins without mfast.
//   * fdim is composed from select() instead of a dedicated builtin.
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
  \
  METAL_FUNC otype abs(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_fabs(cx, mfast)); \
  } \
  \
  METAL_FUNC otype acos(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_acos(cx, mfast)); \
  } \
  \
  METAL_FUNC otype acosh(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_acosh(cx, mfast)); \
  } \
  \
  METAL_FUNC otype asin(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_asin(cx, mfast)); \
  } \
  \
  METAL_FUNC otype asinh(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_asinh(cx, mfast)); \
  } \
  \
  METAL_FUNC otype atan(itype y_over_x) { \
    const ctype c_ratio = static_cast<ctype>(y_over_x); \
    return static_cast<otype>(__metal_atan(c_ratio, mfast)); \
  } \
  \
  METAL_FUNC otype atan2(itype y, itype x) { \
    const ctype cy = static_cast<ctype>(y); \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_atan2(cy, cx, mfast)); \
  } \
  \
  METAL_FUNC otype atanh(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_atanh(cx, mfast)); \
  } \
  \
  METAL_FUNC otype ceil(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_ceil(cx, mfast)); \
  } \
  \
  METAL_FUNC otype cos(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_cos(cx, mfast)); \
  } \
  \
  METAL_FUNC otype cosh(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_cosh(cx, mfast)); \
  } \
  \
  METAL_FUNC otype cospi(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_cospi(cx, mfast)); \
  } \
  \
  METAL_FUNC otype divide(itype x, itype y) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_divide(cx, cy, mfast)); \
  } \
  \
  METAL_FUNC otype exp(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_exp(cx, mfast)); \
  } \
  \
  METAL_FUNC otype exp10(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_exp10(cx, mfast)); \
  } \
  \
  METAL_FUNC otype exp2(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_exp2(cx, mfast)); \
  } \
  \
  METAL_FUNC otype fabs(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_fabs(cx, mfast)); \
  } \
  \
  /* The difference is taken in itype first, then converted to ctype. */ \
  METAL_FUNC otype fdim(itype x, itype y) { \
    const ctype diff = static_cast<ctype>(x - y); \
    return static_cast<otype>( \
        select(diff, ctype(0), diff < ctype(0) || x == y)); \
  } \
  \
  METAL_FUNC otype floor(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_floor(cx, mfast)); \
  } \
  \
  METAL_FUNC otype fma(itype x, itype y, itype z) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    const ctype cz = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fma(cx, cy, cz)); \
  } \
  \
  METAL_FUNC otype fmax(itype x, itype y) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_fmax(cx, cy, mfast)); \
  } \
  \
  METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    const ctype cz = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fmax3(cx, cy, cz, mfast)); \
  } \
  \
  METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    const ctype cz = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fmedian3(cx, cy, cz, mfast)); \
  } \
  \
  METAL_FUNC otype fmin(itype x, itype y) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_fmin(cx, cy, mfast)); \
  } \
  \
  METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    const ctype cz = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fmin3(cx, cy, cz, mfast)); \
  } \
  \
  METAL_FUNC otype fmod(itype x, itype y) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_fmod(cx, cy, mfast)); \
  } \
  \
  METAL_FUNC otype fract(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_fract(cx, mfast)); \
  } \
  \
  /* The builtin receives the address of the caller's `exp` argument. */ \
  METAL_FUNC otype frexp(itype x, thread int& exp) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_frexp(cx, &exp)); \
  } \
  \
  METAL_FUNC otype ldexp(itype x, int k) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_ldexp(cx, k, mfast)); \
  } \
  \
  METAL_FUNC otype log(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_log(cx, mfast)); \
  } \
  \
  METAL_FUNC otype log10(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_log10(cx, mfast)); \
  } \
  \
  METAL_FUNC otype log2(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_log2(cx, mfast)); \
  } \
  \
  METAL_FUNC otype max(itype x, itype y) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_fmax(cx, cy, mfast)); \
  } \
  \
  METAL_FUNC otype max3(itype x, itype y, itype z) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    const ctype cz = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fmax3(cx, cy, cz, mfast)); \
  } \
  \
  METAL_FUNC otype median3(itype x, itype y, itype z) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    const ctype cz = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fmedian3(cx, cy, cz, mfast)); \
  } \
  \
  METAL_FUNC otype min(itype x, itype y) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_fmin(cx, cy, mfast)); \
  } \
  \
  METAL_FUNC otype min3(itype x, itype y, itype z) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    const ctype cz = static_cast<ctype>(z); \
    return static_cast<otype>(__metal_fmin3(cx, cy, cz, mfast)); \
  } \
  \
  METAL_FUNC otype nextafter(itype x, itype y) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_nextafter(cx, cy)); \
  } \
  \
  METAL_FUNC otype pow(itype x, itype y) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_pow(cx, cy, mfast)); \
  } \
  \
  METAL_FUNC otype powr(itype x, itype y) { \
    const ctype cx = static_cast<ctype>(x); \
    const ctype cy = static_cast<ctype>(y); \
    return static_cast<otype>(__metal_powr(cx, cy, mfast)); \
  } \
  \
  METAL_FUNC otype rint(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_rint(cx, mfast)); \
  } \
  \
  METAL_FUNC otype round(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_round(cx, mfast)); \
  } \
  \
  METAL_FUNC otype rsqrt(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_rsqrt(cx, mfast)); \
  } \
  \
  METAL_FUNC otype sin(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_sin(cx, mfast)); \
  } \
  \
  METAL_FUNC otype sinh(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_sinh(cx, mfast)); \
  } \
  \
  METAL_FUNC otype sinpi(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_sinpi(cx, mfast)); \
  } \
  \
  METAL_FUNC otype sqrt(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_sqrt(cx, mfast)); \
  } \
  \
  METAL_FUNC otype tan(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_tan(cx, mfast)); \
  } \
  \
  METAL_FUNC otype tanh(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_tanh(cx, mfast)); \
  } \
  \
  METAL_FUNC otype tanpi(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_tanpi(cx, mfast)); \
  } \
  \
  METAL_FUNC otype trunc(itype x) { \
    const ctype cx = static_cast<ctype>(x); \
    return static_cast<otype>(__metal_trunc(cx, mfast)); \
  }
227
228namespace metal {
229
233 float,
234 __METAL_MAYBE_FAST_MATH__);
235
236namespace fast {
237
241 float,
242 __METAL_FAST_MATH__);
243
244} // namespace fast
245
246namespace precise {
247
251 float,
252 __METAL_PRECISE_MATH__);
253
254} // namespace precise
255
256} // namespace metal
257
259// Metal simd for bfloat16
261
// Generates SIMD-group data-movement overloads (broadcast / shuffle family)
// for `itype`.
//
//   itype          - parameter type of the generated overloads
//   otype          - return type of the generated overloads
//   ctype          - accepted for signature symmetry; not referenced in the body
//   itype_to_ctype - callable/macro applied to each data operand before the
//                    __metal_simd_* builtin is invoked
//   ctype_to_otype - callable/macro applied to the builtin's result
//
// The three-argument shuffle_and_fill overloads pass
// __metal_get_simdgroup_size(ushort()) as the modulo.
#define instantiate_metal_simd_comm_funcs( \
    itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
  \
  METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
    const auto raw = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_broadcast(raw, broadcast_lane_id)); \
  } \
  \
  METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
    const auto raw = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_shuffle(raw, simd_lane_id)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    const auto raw = itype_to_ctype(data); \
    const auto raw_fill = itype_to_ctype(filling_data); \
    return ctype_to_otype( \
        __metal_simd_shuffle_and_fill_down(raw, raw_fill, delta, modulo)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta) { \
    const auto raw = itype_to_ctype(data); \
    const auto raw_fill = itype_to_ctype(filling_data); \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
        raw, raw_fill, delta, __metal_get_simdgroup_size(ushort()))); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    const auto raw = itype_to_ctype(data); \
    const auto raw_fill = itype_to_ctype(filling_data); \
    return ctype_to_otype( \
        __metal_simd_shuffle_and_fill_up(raw, raw_fill, delta, modulo)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta) { \
    const auto raw = itype_to_ctype(data); \
    const auto raw_fill = itype_to_ctype(filling_data); \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
        raw, raw_fill, delta, __metal_get_simdgroup_size(ushort()))); \
  } \
  \
  METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
    const auto raw = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_shuffle_down(raw, delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
    const auto raw = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_shuffle_rotate_down(raw, delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
    const auto raw = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_shuffle_rotate_up(raw, delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
    const auto raw = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_shuffle_up(raw, delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
    const auto raw = itype_to_ctype(data); \
    return ctype_to_otype(__metal_simd_shuffle_xor(raw, mask)); \
  }
329
// Generates SIMD-group reduction and prefix-scan overloads for `itype`.
//
//   itype - parameter type of the generated overloads
//   otype - return type of the generated overloads
//   ctype - type the operand is converted to before calling the builtin
//
// Each overload static_casts its operand to ctype, calls the matching
// __metal_simd_* builtin and static_casts the result to otype.
// NOTE(review): simd_xor hands a ctype value to __metal_simd_xor; confirm the
// builtin accepts whichever ctype this macro is instantiated with.
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
  \
  METAL_FUNC otype simd_max(itype data) { \
    const ctype lane_value = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_max(lane_value)); \
  } \
  \
  METAL_FUNC otype simd_min(itype data) { \
    const ctype lane_value = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_min(lane_value)); \
  } \
  \
  METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
    const ctype lane_value = static_cast<ctype>(data); \
    return static_cast<otype>( \
        __metal_simd_prefix_exclusive_product(lane_value)); \
  } \
  \
  METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
    const ctype lane_value = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_prefix_exclusive_sum(lane_value)); \
  } \
  \
  METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
    const ctype lane_value = static_cast<ctype>(data); \
    return static_cast<otype>( \
        __metal_simd_prefix_inclusive_product(lane_value)); \
  } \
  \
  METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
    const ctype lane_value = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_prefix_inclusive_sum(lane_value)); \
  } \
  \
  METAL_FUNC otype simd_product(itype data) { \
    const ctype lane_value = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_product(lane_value)); \
  } \
  \
  METAL_FUNC otype simd_sum(itype data) { \
    const ctype lane_value = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_sum(lane_value)); \
  } \
  \
  METAL_FUNC otype simd_xor(itype data) { \
    const ctype lane_value = static_cast<ctype>(data); \
    return static_cast<otype>(__metal_simd_xor(lane_value)); \
  }
371
// Bit-level conversion helpers between a bfloat16 value and its 16-bit
// pattern. They are handed to instantiate_metal_simd_comm_funcs as the
// itype_to_ctype / ctype_to_otype converters.
#if defined(METAL_3_1) || (__METAL_VERSION__ >= 310)

// METAL_3_1 defined or __METAL_VERSION__ >= 310: reinterpret via as_type.
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

#else

// Otherwise: read the `bits_` member directly, and rebuild through the
// _MLX_BFloat16 constructor that takes the bits_to_bfloat() tag.
// The argument is parenthesized so that compound expressions (for example
// `cond ? a : b`) select the member of the whole expression instead of
// binding `.bits_` to its last operand.
#define bfloat16_to_uint16(x) (x).bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())

#endif
383
384namespace metal {
385
389 uint16_t,
393
394} // namespace metal
#define uint16_to_bfloat16(x)
Definition bf16_math.h:380
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)
Definition bf16_math.h:330
#define bfloat16_to_uint16(x)
Definition bf16_math.h:379
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast)
Definition bf16_math.h:35
#define instantiate_metal_simd_comm_funcs( itype, otype, ctype, itype_to_ctype, ctype_to_otype)
Definition bf16_math.h:262
Definition bf16.h:265
Definition bf16.h:54