// Defines overloads of the Metal math functions for `itype` arguments.
//
// Every generated function converts its arguments to `ctype`, calls the
// matching `__metal_*` builtin, and converts the builtin's result to `otype`.
//
//   itype - argument type of the generated overloads
//   otype - return type of the generated overloads
//   ctype - type the computation is carried out in
//   mfast - fast-math selector forwarded as the last argument of every builtin
//           that accepts one (e.g. __METAL_FAST_MATH__, __METAL_PRECISE_MATH__)
//
// Notes:
//   * fma, frexp and nextafter call builtins that take no mfast argument.
//   * abs forwards to __metal_fabs; max/min/max3/min3/median3 forward to the
//     corresponding f-prefixed builtins (fmax/fmin/fmax3/fmin3/fmedian3).
//   * fdim uses select() to return ctype(0) when x - y is negative or x == y,
//     and the difference otherwise; a NaN difference fails both comparisons
//     and is therefore returned unchanged.
//
// Comments are kept outside the macro body: a `//` comment on a
// backslash-continued line would swallow the line that follows it.
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
  METAL_FUNC otype abs(itype x) { \
    return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype acos(itype x) { \
    return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype acosh(itype x) { \
    return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype asin(itype x) { \
    return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype asinh(itype x) { \
    return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype atan(itype y_over_x) { \
    return static_cast<otype>( \
        __metal_atan(static_cast<ctype>(y_over_x), mfast)); \
  } \
  METAL_FUNC otype atan2(itype y, itype x) { \
    return static_cast<otype>( \
        __metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype atanh(itype x) { \
    return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype ceil(itype x) { \
    return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cos(itype x) { \
    return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cosh(itype x) { \
    return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cospi(itype x) { \
    return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype divide(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype exp(itype x) { \
    return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype exp10(itype x) { \
    return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype exp2(itype x) { \
    return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fabs(itype x) { \
    return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fdim(itype x, itype y) { \
    ctype t = static_cast<ctype>(x - y); \
    return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
  } \
  METAL_FUNC otype floor(itype x) { \
    return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fma(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fma( \
        static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
  } \
  METAL_FUNC otype fmax(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmax3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmedian3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmin(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmin3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmod(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fract(itype x) { \
    return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype frexp(itype x, thread int& exp) { \
    return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
  } \
  METAL_FUNC otype ldexp(itype x, int k) { \
    return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
  } \
  METAL_FUNC otype log(itype x) { \
    return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype log10(itype x) { \
    return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype log2(itype x) { \
    return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype max(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype max3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmax3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype median3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmedian3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype min(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype min3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmin3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype nextafter(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
  } \
  METAL_FUNC otype pow(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype powr(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype rint(itype x) { \
    return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype round(itype x) { \
    return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype rsqrt(itype x) { \
    return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sin(itype x) { \
    return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sinh(itype x) { \
    return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sinpi(itype x) { \
    return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sqrt(itype x) { \
    return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tan(itype x) { \
    return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tanh(itype x) { \
    return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tanpi(itype x) { \
    return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype trunc(itype x) { \
    return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
  }
234 __METAL_MAYBE_FAST_MATH__);
242 __METAL_FAST_MATH__);
252 __METAL_PRECISE_MATH__);
// Defines SIMD-group communication overloads (broadcast / shuffle variants)
// for `itype` payloads.
//
// Each generated function applies `itype_to_ctype` to its payload argument(s),
// passes the result to the matching `__metal_simd_*` builtin, and applies
// `ctype_to_otype` to the builtin's result. Lane ids, deltas, masks and
// modulos are forwarded unchanged.
//
//   itype          - payload type accepted by the generated overloads
//   otype          - return type of the generated overloads
//   ctype          - not referenced in the generated bodies
//   itype_to_ctype - conversion applied to payloads before each builtin call
//   ctype_to_otype - conversion applied to each builtin's result
//
// NOTE(review): the conversions are presumably the bit-reinterpreting
// bfloat16_to_uint16 / uint16_to_bfloat16 macros defined later in this
// file — confirm at the instantiation site.
//
// The three-argument simd_shuffle_and_fill_{down,up} overloads use the
// SIMD-group size (__metal_get_simdgroup_size) as the modulo.
//
// Comments are kept outside the macro body: a `//` comment on a
// backslash-continued line would swallow the line that follows it.
#define instantiate_metal_simd_comm_funcs( \
    itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
  METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
    return ctype_to_otype( \
        __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
  } \
  METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
    return ctype_to_otype( \
        __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
  } \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
  } \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
        itype_to_ctype(data), \
        itype_to_ctype(filling_data), \
        delta, \
        __metal_get_simdgroup_size(ushort()))); \
  } \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
  } \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
        itype_to_ctype(data), \
        itype_to_ctype(filling_data), \
        delta, \
        __metal_get_simdgroup_size(ushort()))); \
  } \
  METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
  } \
  METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
  } \
  METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
  } \
  METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
  } \
  METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
  }
// Defines SIMD-group reduction and prefix-scan overloads for `itype` data.
//
// Each generated function static_casts its argument to `ctype`, calls the
// matching `__metal_simd_*` builtin (max, min, product, sum, xor, and the
// inclusive/exclusive prefix sum and product), and static_casts the
// builtin's result to `otype`.
//
//   itype - argument type of the generated overloads
//   otype - return type of the generated overloads
//   ctype - type passed to the builtins
//
// Comments are kept outside the macro body: a `//` comment on a
// backslash-continued line would swallow the line that follows it.
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
  METAL_FUNC otype simd_max(itype data) { \
    return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
  } \
  METAL_FUNC otype simd_min(itype data) { \
    return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
  } \
  METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
  } \
  METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
  } \
  METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
  } \
  METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
  } \
  METAL_FUNC otype simd_product(itype data) { \
    return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
  } \
  METAL_FUNC otype simd_sum(itype data) { \
    return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
  } \
  METAL_FUNC otype simd_xor(itype data) { \
    return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
  }
372#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
374#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
375#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
379#define bfloat16_to_uint16(x) x.bits_
380#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
#define uint16_to_bfloat16(x)
Definition bf16_math.h:380
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)
Definition bf16_math.h:330
#define bfloat16_to_uint16(x)
Definition bf16_math.h:379
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast)
Definition bf16_math.h:35
#define instantiate_metal_simd_comm_funcs( itype, otype, ctype, itype_to_ctype, ctype_to_otype)
Definition bf16_math.h:262