33#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
35 METAL_FUNC otype abs(itype x) { \
36 return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
38 METAL_FUNC otype acos(itype x) { \
39 return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
41 METAL_FUNC otype acosh(itype x) { \
42 return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
44 METAL_FUNC otype asin(itype x) { \
45 return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
47 METAL_FUNC otype asinh(itype x) { \
48 return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
50 METAL_FUNC otype atan(itype y_over_x) { \
51 return static_cast<otype>( \
52 __metal_atan(static_cast<ctype>(y_over_x), mfast)); \
54 METAL_FUNC otype atan2(itype y, itype x) { \
55 return static_cast<otype>( \
56 __metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
58 METAL_FUNC otype atanh(itype x) { \
59 return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
61 METAL_FUNC otype ceil(itype x) { \
62 return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
64 METAL_FUNC otype cos(itype x) { \
65 return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
67 METAL_FUNC otype cosh(itype x) { \
68 return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
70 METAL_FUNC otype cospi(itype x) { \
71 return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
73 METAL_FUNC otype divide(itype x, itype y) { \
74 return static_cast<otype>( \
75 __metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
77 METAL_FUNC otype exp(itype x) { \
78 return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
80 METAL_FUNC otype exp10(itype x) { \
81 return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
83 METAL_FUNC otype exp2(itype x) { \
84 return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
86 METAL_FUNC otype fabs(itype x) { \
87 return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
89 METAL_FUNC otype fdim(itype x, itype y) { \
90 ctype t = static_cast<ctype>(x - y); \
91 return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
93 METAL_FUNC otype floor(itype x) { \
94 return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
96 METAL_FUNC otype fma(itype x, itype y, itype z) { \
97 return static_cast<otype>(__metal_fma( \
98 static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
100 METAL_FUNC otype fmax(itype x, itype y) { \
101 return static_cast<otype>( \
102 __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
104 METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
105 return static_cast<otype>(__metal_fmax3( \
106 static_cast<ctype>(x), \
107 static_cast<ctype>(y), \
108 static_cast<ctype>(z), \
111 METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
112 return static_cast<otype>(__metal_fmedian3( \
113 static_cast<ctype>(x), \
114 static_cast<ctype>(y), \
115 static_cast<ctype>(z), \
118 METAL_FUNC otype fmin(itype x, itype y) { \
119 return static_cast<otype>( \
120 __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
122 METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
123 return static_cast<otype>(__metal_fmin3( \
124 static_cast<ctype>(x), \
125 static_cast<ctype>(y), \
126 static_cast<ctype>(z), \
129 METAL_FUNC otype fmod(itype x, itype y) { \
130 return static_cast<otype>( \
131 __metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
133 METAL_FUNC otype fract(itype x) { \
134 return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
136 METAL_FUNC otype frexp(itype x, thread int& exp) { \
137 return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
139 METAL_FUNC otype ldexp(itype x, int k) { \
140 return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
142 METAL_FUNC otype log(itype x) { \
143 return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
145 METAL_FUNC otype log10(itype x) { \
146 return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
148 METAL_FUNC otype log2(itype x) { \
149 return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
151 METAL_FUNC otype max(itype x, itype y) { \
152 return static_cast<otype>( \
153 __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
155 METAL_FUNC otype max3(itype x, itype y, itype z) { \
156 return static_cast<otype>(__metal_fmax3( \
157 static_cast<ctype>(x), \
158 static_cast<ctype>(y), \
159 static_cast<ctype>(z), \
162 METAL_FUNC otype median3(itype x, itype y, itype z) { \
163 return static_cast<otype>(__metal_fmedian3( \
164 static_cast<ctype>(x), \
165 static_cast<ctype>(y), \
166 static_cast<ctype>(z), \
169 METAL_FUNC otype min(itype x, itype y) { \
170 return static_cast<otype>( \
171 __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
173 METAL_FUNC otype min3(itype x, itype y, itype z) { \
174 return static_cast<otype>(__metal_fmin3( \
175 static_cast<ctype>(x), \
176 static_cast<ctype>(y), \
177 static_cast<ctype>(z), \
180 METAL_FUNC otype nextafter(itype x, itype y) { \
181 return static_cast<otype>( \
182 __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
184 METAL_FUNC otype pow(itype x, itype y) { \
185 return static_cast<otype>( \
186 __metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
188 METAL_FUNC otype powr(itype x, itype y) { \
189 return static_cast<otype>( \
190 __metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
192 METAL_FUNC otype rint(itype x) { \
193 return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
195 METAL_FUNC otype round(itype x) { \
196 return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
198 METAL_FUNC otype rsqrt(itype x) { \
199 return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
201 METAL_FUNC otype sin(itype x) { \
202 return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
204 METAL_FUNC otype sinh(itype x) { \
205 return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
207 METAL_FUNC otype sinpi(itype x) { \
208 return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
210 METAL_FUNC otype sqrt(itype x) { \
211 return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
213 METAL_FUNC otype tan(itype x) { \
214 return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
216 METAL_FUNC otype tanh(itype x) { \
217 return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
219 METAL_FUNC otype tanpi(itype x) { \
220 return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
222 METAL_FUNC otype trunc(itype x) { \
223 return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
232 __METAL_MAYBE_FAST_MATH__);
240 __METAL_FAST_MATH__);
250 __METAL_PRECISE_MATH__);
260#define instantiate_metal_simd_comm_funcs( \
261 itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
263 METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
264 return ctype_to_otype( \
265 __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
268 METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
269 return ctype_to_otype( \
270 __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
273 METAL_FUNC otype simd_shuffle_and_fill_down( \
274 itype data, itype filling_data, ushort delta, ushort modulo) { \
275 return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
276 itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
279 METAL_FUNC otype simd_shuffle_and_fill_down( \
280 itype data, itype filling_data, ushort delta) { \
281 return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
282 itype_to_ctype(data), \
283 itype_to_ctype(filling_data), \
285 __metal_get_simdgroup_size(ushort()))); \
288 METAL_FUNC otype simd_shuffle_and_fill_up( \
289 itype data, itype filling_data, ushort delta, ushort modulo) { \
290 return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
291 itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
294 METAL_FUNC otype simd_shuffle_and_fill_up( \
295 itype data, itype filling_data, ushort delta) { \
296 return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
297 itype_to_ctype(data), \
298 itype_to_ctype(filling_data), \
300 __metal_get_simdgroup_size(ushort()))); \
303 METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
304 return ctype_to_otype( \
305 __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
308 METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
309 return ctype_to_otype( \
310 __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
313 METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
314 return ctype_to_otype( \
315 __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
318 METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
319 return ctype_to_otype( \
320 __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
323 METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
324 return ctype_to_otype( \
325 __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
328#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
330 METAL_FUNC otype simd_max(itype data) { \
331 return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
334 METAL_FUNC otype simd_min(itype data) { \
335 return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
338 METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
339 return static_cast<otype>( \
340 __metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
343 METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
344 return static_cast<otype>( \
345 __metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
348 METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
349 return static_cast<otype>( \
350 __metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
353 METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
354 return static_cast<otype>( \
355 __metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
358 METAL_FUNC otype simd_product(itype data) { \
359 return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
362 METAL_FUNC otype simd_sum(itype data) { \
363 return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
366 METAL_FUNC otype simd_xor(itype data) { \
367 return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)
Definition bf16_math.h:328
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast)
Definition bf16_math.h:33
#define instantiate_metal_simd_comm_funcs( itype, otype, ctype, itype_to_ctype, ctype_to_otype)
Definition bf16_math.h:260