MLX
Loading...
Searching...
No Matches
Namespaces | Macros | Functions
bf16_math.h File Reference
#include "mlx/backend/metal/kernels/bf16.h"

Go to the source code of this file.

Namespaces

namespace  metal
 
namespace  metal::fast
 
namespace  metal::precise
 

Macros

#define instantiate_metal_math_funcs(itype, otype, ctype, mfast)
 
#define instantiate_metal_simd_comm_funcs( itype, otype, ctype, itype_to_ctype, ctype_to_otype)
 
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)
 
#define bfloat16_to_uint16(x)   x.bits_
 
#define uint16_to_bfloat16(x)   _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
 

Functions

METAL_FUNC bfloat16_t metal::abs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::acos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::acosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::asin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::asinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::atan (bfloat16_t y_over_x)
 
METAL_FUNC bfloat16_t metal::atan2 (bfloat16_t y, bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::atanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::ceil (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::cos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::cosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::cospi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::divide (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::exp (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::exp10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::exp2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fabs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fdim (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::floor (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fmax (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fmin (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fmod (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fract (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::frexp (bfloat16_t x, thread int &exp)
 
METAL_FUNC bfloat16_t metal::ldexp (bfloat16_t x, int k)
 
METAL_FUNC bfloat16_t metal::log (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::log10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::log2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::max (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::min (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::nextafter (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::pow (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::powr (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::rint (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::round (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::rsqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::sin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::sinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::sinpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::sqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::tan (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::tanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::tanpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::trunc (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::abs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::acos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::acosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::asin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::asinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::atan (bfloat16_t y_over_x)
 
METAL_FUNC bfloat16_t metal::fast::atan2 (bfloat16_t y, bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::atanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::ceil (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::cos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::cosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::cospi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::divide (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::exp (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::exp10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::exp2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::fabs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::fdim (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::floor (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::fmax (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::fmin (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::fmod (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::fract (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::frexp (bfloat16_t x, thread int &exp)
 
METAL_FUNC bfloat16_t metal::fast::ldexp (bfloat16_t x, int k)
 
METAL_FUNC bfloat16_t metal::fast::log (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::log10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::log2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::max (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::min (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::fast::nextafter (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::pow (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::powr (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::fast::rint (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::round (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::rsqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::sin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::sinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::sinpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::sqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::tan (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::tanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::tanpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::fast::trunc (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::abs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::acos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::acosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::asin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::asinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::atan (bfloat16_t y_over_x)
 
METAL_FUNC bfloat16_t metal::precise::atan2 (bfloat16_t y, bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::atanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::ceil (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::cos (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::cosh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::cospi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::divide (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::exp (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::exp10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::exp2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::fabs (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::fdim (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::floor (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::fma (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::fmax (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::fmax3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::fmedian3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::fmin (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::fmin3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::fmod (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::fract (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::frexp (bfloat16_t x, thread int &exp)
 
METAL_FUNC bfloat16_t metal::precise::ldexp (bfloat16_t x, int k)
 
METAL_FUNC bfloat16_t metal::precise::log (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::log10 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::log2 (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::max (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::max3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::median3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::min (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::min3 (bfloat16_t x, bfloat16_t y, bfloat16_t z)
 
METAL_FUNC bfloat16_t metal::precise::nextafter (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::pow (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::powr (bfloat16_t x, bfloat16_t y)
 
METAL_FUNC bfloat16_t metal::precise::rint (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::round (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::rsqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::sin (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::sinh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::sinpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::sqrt (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::tan (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::tanh (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::tanpi (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::precise::trunc (bfloat16_t x)
 
METAL_FUNC bfloat16_t metal::simd_broadcast (bfloat16_t data, ushort broadcast_lane_id)
 
METAL_FUNC bfloat16_t metal::simd_shuffle (bfloat16_t data, ushort simd_lane_id)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_down (bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_down (bfloat16_t data, bfloat16_t filling_data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_up (bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_and_fill_up (bfloat16_t data, bfloat16_t filling_data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_down (bfloat16_t data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_rotate_down (bfloat16_t data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_rotate_up (bfloat16_t data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_up (bfloat16_t data, ushort delta)
 
METAL_FUNC bfloat16_t metal::simd_shuffle_xor (bfloat16_t data, ushort mask)
 
METAL_FUNC bfloat16_t metal::simd_max (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_min (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_prefix_exclusive_product (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_prefix_exclusive_sum (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_prefix_inclusive_product (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_prefix_inclusive_sum (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_product (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_sum (bfloat16_t data)
 
METAL_FUNC bfloat16_t metal::simd_xor (bfloat16_t data)
 

Macro Definition Documentation

◆ bfloat16_to_uint16

#define bfloat16_to_uint16 ( x)    x.bits_

◆ instantiate_metal_math_funcs

#define instantiate_metal_math_funcs ( itype,
otype,
ctype,
mfast )

◆ instantiate_metal_simd_comm_funcs

#define instantiate_metal_simd_comm_funcs ( itype,
otype,
ctype,
itype_to_ctype,
ctype_to_otype )

◆ instantiate_metal_simd_reduction_funcs

#define instantiate_metal_simd_reduction_funcs ( itype,
otype,
ctype )

◆ uint16_to_bfloat16

#define uint16_to_bfloat16 ( x)    _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())