9#if defined(__HAVE_BFLOAT__) 
   21  if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
 
   22      _fp_encoding_traits<float>::inf_mask) {
 
   23    return uint16_t(as_type<uint32_t>(0x7FC0));
 
   26  uint32_t float_bits = as_type<uint32_t>(x);
 
   29  float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
 
   32  return float_bits >> 16;
 
 
   37  return as_type<float>((uint32_t)x << 16);
 
 
   44    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
 
   48    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
 
   75      typename = 
typename enable_if<can_convert_to_bfloat<T>>::type>
 
   81      typename = 
typename enable_if<can_convert_to_bfloat<T>>::type>
 
   87      typename = 
typename enable_if<can_convert_to_bfloat<T>>::type>
 
   93      typename = 
typename enable_if<can_convert_to_bfloat<T>>::type>
 
  102      typename = 
typename enable_if<can_convert_from_bfloat<T>>::type>
 
  103  constexpr METAL_FUNC 
operator T() const thread {
 
 
  109      typename = 
typename enable_if<can_convert_from_bfloat<T>>::type>
 
  110  constexpr METAL_FUNC 
operator T() const threadgroup {
 
 
  116      typename = 
typename enable_if<can_convert_from_bfloat<T>>::type>
 
  117  constexpr METAL_FUNC 
operator T() const device {
 
 
  123      typename = 
typename enable_if<can_convert_from_bfloat<T>>::type>
 
  124  constexpr METAL_FUNC 
operator T() const constant {
 
 
 
  136  return -
static_cast<float>(x);
 
 
  141#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ 
  142  constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) {           \ 
  143    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);          \ 
 
  146#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype)    \ 
  147  constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ 
  148    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);        \ 
  150  constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ 
  151    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);        \ 
 
  156#define bfloat_binop(_op_, _operator_)                                       \ 
  158      _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ 
  159  bfloat_binop_helper(_op_, _operator_, float, float, float);                \ 
  160  bfloat_binop_helper(_op_, _operator_, float, half, float);                 \ 
  161  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float);      \ 
  162  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float);     \ 
  163  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float);      \ 
  164  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); 
 
  173#define bfloat_compop(__op__, __operator__)                             \ 
  175      __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ 
  176  bfloat_binop_helper(__op__, __operator__, bool, float, float);        \ 
  177  bfloat_binop_helper(__op__, __operator__, bool, half, float);         \ 
  178  bfloat_binop_helper(__op__, __operator__, bool, int32_t, float);      \ 
  179  bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float);     \ 
  180  bfloat_binop_helper(__op__, __operator__, bool, int64_t, float);      \ 
  181  bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); 
 
  191#undef bfloat_binop_base 
  192#undef bfloat_binop_helper 
  197#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ 
  198  constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__(            \ 
  199      addr_space _MLX_BFloat16& lhs, itype rhs) {                         \ 
  200    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs);         \ 
  203  constexpr METAL_FUNC addr_space itype& __operator__(                    \ 
  204      addr_space itype& lhs, _MLX_BFloat16 rhs) {                         \ 
  205    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs);         \ 
 
  209#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ 
  210  bfloat_inplace_op_helper(__op__, __operator__, itype, device);         \ 
  211  bfloat_inplace_op_helper(__op__, __operator__, itype, thread);         \ 
  212  bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); 
 
  214#define bfloat_inplace_op(itype)                             \ 
  215  bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ 
  216  bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ 
  217  bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ 
  218  bfloat_inplace_op_addr_space_helper(/, operator/=, itype); 
 
  229#undef bfloat_inplace_op_helper 
  230#undef bfloat_inplace_op_addr_space_helper 
  231#undef bfloat_inplace_op 
  233#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ 
  234  constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__(     \ 
  235      addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) {          \ 
  236    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs);  \ 
  240#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ 
  241  bfloat_inplace_op_helper(__op__, __operator__, device);         \ 
  242  bfloat_inplace_op_helper(__op__, __operator__, thread);         \ 
  243  bfloat_inplace_op_helper(__op__, __operator__, threadgroup); 
  250#undef bfloat_inplace_op_helper 
  251#undef bfloat_inplace_op_addr_space_helper 
  263#pragma METAL internals : enable 
  268struct _numeric_limits_impl<
bfloat16_t> : _fp_numeric_limits_impl_base {
 
  269  static constexpr constant 
int digits = 8;
 
  270  static constexpr constant 
int digits10 = 2;
 
  271  static constexpr constant 
int max_digits10 = 4;
 
  272  static constexpr constant 
int radix = 2;
 
  273  static constexpr constant 
int min_exponent = -125;
 
  274  static constexpr constant 
int min_exponent10 = -37;
 
  275  static constexpr constant 
int max_exponent = 128;
 
  276  static constexpr constant 
int max_exponent10 = 38;
 
 
 
  313#pragma METAL internals : disable 
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
Definition bf16.h:76
 
uint16_t bits_
Definition bf16.h:57
 
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
Definition bf16.h:67
 
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat()
Definition bf16.h:64
 
_MLX_BFloat16() thread=default
 
constexpr METAL_FUNC _MLX_BFloat16(T x) device
Definition bf16.h:88
 
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
Definition bf16.h:82
 
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
Definition bf16.h:94