// Generates the scalar math overloads for one (itype, otype, ctype) triple.
//
//   itype - parameter type of every generated overload
//   otype - return type of every generated overload
//   ctype - type the arguments are cast to before the builtin is called
//   mfast - math-mode argument forwarded to the __metal_* builtins
//
// Unless noted below, each overload casts its arguments to ctype, calls the
// __metal_* builtin of the same name with mfast as the last argument, and
// casts the result to otype.
//
//   * abs and fabs both forward to __metal_fabs.
//   * max/max3, min/min3 and median3 forward to the fmax/fmax3, fmin/fmin3
//     and fmedian3 builtins respectively.
//   * fma, frexp and nextafter call their builtin without mfast.
//   * fdim is written inline: the difference is taken in itype, cast to
//     ctype, and replaced by zero via select() when it is negative or when
//     x == y.
//
// NOTE(review): the closing braces and the trailing `mfast` argument of
// fmax3/fmedian3/fmin3/max3/median3/min3 were missing from the listing this
// block was recovered from; they are restored by analogy with the
// two-argument forms. Confirm against the upstream header.
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
  METAL_FUNC otype abs(itype x) { \
    return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype acos(itype x) { \
    return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype acosh(itype x) { \
    return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype asin(itype x) { \
    return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype asinh(itype x) { \
    return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype atan(itype y_over_x) { \
    return static_cast<otype>( \
        __metal_atan(static_cast<ctype>(y_over_x), mfast)); \
  } \
  METAL_FUNC otype atan2(itype y, itype x) { \
    return static_cast<otype>( \
        __metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype atanh(itype x) { \
    return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype ceil(itype x) { \
    return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cos(itype x) { \
    return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cosh(itype x) { \
    return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cospi(itype x) { \
    return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype divide(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype exp(itype x) { \
    return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype exp10(itype x) { \
    return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype exp2(itype x) { \
    return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fabs(itype x) { \
    return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fdim(itype x, itype y) { \
    ctype t = static_cast<ctype>(x - y); \
    return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
  } \
  METAL_FUNC otype floor(itype x) { \
    return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fma(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fma( \
        static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
  } \
  METAL_FUNC otype fmax(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmax3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmedian3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmin(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmin3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmod(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fract(itype x) { \
    return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype frexp(itype x, thread int& exp) { \
    return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
  } \
  METAL_FUNC otype ldexp(itype x, int k) { \
    return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
  } \
  METAL_FUNC otype log(itype x) { \
    return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype log10(itype x) { \
    return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype log2(itype x) { \
    return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype max(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype max3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmax3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype median3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmedian3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype min(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype min3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmin3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype nextafter(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
  } \
  METAL_FUNC otype pow(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype powr(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype rint(itype x) { \
    return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype round(itype x) { \
    return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype rsqrt(itype x) { \
    return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sin(itype x) { \
    return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sinh(itype x) { \
    return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sinpi(itype x) { \
    return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sqrt(itype x) { \
    return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tan(itype x) { \
    return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tanh(itype x) { \
    return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tanpi(itype x) { \
    return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype trunc(itype x) { \
    return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
  }
 
  234    __METAL_MAYBE_FAST_MATH__);
 
  242    __METAL_FAST_MATH__);
 
 
  252    __METAL_PRECISE_MATH__);
 
 
  262#define instantiate_metal_simd_comm_funcs(                                   \ 
  263    itype, otype, ctype, itype_to_ctype, ctype_to_otype)                     \ 
  265  METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) {    \ 
  266    return ctype_to_otype(                                                   \ 
  267        __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id));    \ 
  270  METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) {           \ 
  271    return ctype_to_otype(                                                   \ 
  272        __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id));           \ 
  275  METAL_FUNC otype simd_shuffle_and_fill_down(                               \ 
  276      itype data, itype filling_data, ushort delta, ushort modulo) {         \ 
  277    return ctype_to_otype(__metal_simd_shuffle_and_fill_down(                \ 
  278        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ 
  281  METAL_FUNC otype simd_shuffle_and_fill_down(                               \ 
  282      itype data, itype filling_data, ushort delta) {                        \ 
  283    return ctype_to_otype(__metal_simd_shuffle_and_fill_down(                \ 
  284        itype_to_ctype(data),                                                \ 
  285        itype_to_ctype(filling_data),                                        \ 
  287        __metal_get_simdgroup_size(ushort())));                              \ 
  290  METAL_FUNC otype simd_shuffle_and_fill_up(                                 \ 
  291      itype data, itype filling_data, ushort delta, ushort modulo) {         \ 
  292    return ctype_to_otype(__metal_simd_shuffle_and_fill_up(                  \ 
  293        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ 
  296  METAL_FUNC otype simd_shuffle_and_fill_up(                                 \ 
  297      itype data, itype filling_data, ushort delta) {                        \ 
  298    return ctype_to_otype(__metal_simd_shuffle_and_fill_up(                  \ 
  299        itype_to_ctype(data),                                                \ 
  300        itype_to_ctype(filling_data),                                        \ 
  302        __metal_get_simdgroup_size(ushort())));                              \ 
  305  METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) {             \ 
  306    return ctype_to_otype(                                                   \ 
  307        __metal_simd_shuffle_down(itype_to_ctype(data), delta));             \ 
  310  METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) {      \ 
  311    return ctype_to_otype(                                                   \ 
  312        __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta));      \ 
  315  METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) {        \ 
  316    return ctype_to_otype(                                                   \ 
  317        __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta));        \ 
  320  METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) {               \ 
  321    return ctype_to_otype(                                                   \ 
  322        __metal_simd_shuffle_up(itype_to_ctype(data), delta));               \ 
  325  METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) {               \ 
  326    return ctype_to_otype(                                                   \ 
  327        __metal_simd_shuffle_xor(itype_to_ctype(data), mask));               \ 
  330#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)            \ 
  332  METAL_FUNC otype simd_max(itype data) {                                      \ 
  333    return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data)));     \ 
  336  METAL_FUNC otype simd_min(itype data) {                                      \ 
  337    return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data)));     \ 
  340  METAL_FUNC otype simd_prefix_exclusive_product(itype data) {                 \ 
  341    return static_cast<otype>(                                                 \ 
  342        __metal_simd_prefix_exclusive_product(static_cast<ctype>(data)));      \ 
  345  METAL_FUNC otype simd_prefix_exclusive_sum(itype data) {                     \ 
  346    return static_cast<otype>(                                                 \ 
  347        __metal_simd_prefix_exclusive_sum(static_cast<ctype>(data)));          \ 
  350  METAL_FUNC otype simd_prefix_inclusive_product(itype data) {                 \ 
  351    return static_cast<otype>(                                                 \ 
  352        __metal_simd_prefix_inclusive_product(static_cast<ctype>(data)));      \ 
  355  METAL_FUNC otype simd_prefix_inclusive_sum(itype data) {                     \ 
  356    return static_cast<otype>(                                                 \ 
  357        __metal_simd_prefix_inclusive_sum(static_cast<ctype>(data)));          \ 
  360  METAL_FUNC otype simd_product(itype data) {                                  \ 
  361    return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \ 
  364  METAL_FUNC otype simd_sum(itype data) {                                      \ 
  365    return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data)));     \ 
  368  METAL_FUNC otype simd_xor(itype data) {                                      \ 
  369    return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data)));     \ 
 
  372#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310) 
  374#define bfloat16_to_uint16(x) as_type<uint16_t>(x) 
  375#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x) 
  379#define bfloat16_to_uint16(x) x.bits_ 
  380#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat()) 
// Macro index (name -> defining location):
//   uint16_to_bfloat16(x)
//       bf16_math.h:380
//   instantiate_metal_simd_reduction_funcs(itype, otype, ctype)
//       bf16_math.h:330
//   bfloat16_to_uint16(x)
//       bf16_math.h:379
//   instantiate_metal_math_funcs(itype, otype, ctype, mfast)
//       bf16_math.h:35
//   instantiate_metal_simd_comm_funcs(itype, otype, ctype, itype_to_ctype, ctype_to_otype)
//       bf16_math.h:262