9#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
21 if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
22 _fp_encoding_traits<float>::inf_mask) {
23 return uint16_t(as_type<uint32_t>(0x7FC0));
26 uint32_t float_bits = as_type<uint32_t>(x);
29 float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
32 return float_bits >> 16;
37 return as_type<float>((uint32_t)x << 16);
44 !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
48 !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
75 typename =
typename enable_if<can_convert_to_bfloat<T>>::type>
81 typename =
typename enable_if<can_convert_to_bfloat<T>>::type>
87 typename =
typename enable_if<can_convert_to_bfloat<T>>::type>
93 typename =
typename enable_if<can_convert_to_bfloat<T>>::type>
102 typename =
typename enable_if<can_convert_from_bfloat<T>>::type>
103 constexpr METAL_FUNC
operator T() const thread {
109 typename =
typename enable_if<can_convert_from_bfloat<T>>::type>
110 constexpr METAL_FUNC
operator T() const threadgroup {
116 typename =
typename enable_if<can_convert_from_bfloat<T>>::type>
117 constexpr METAL_FUNC
operator T() const device {
123 typename =
typename enable_if<can_convert_from_bfloat<T>>::type>
124 constexpr METAL_FUNC
operator T() const constant {
136 return -
static_cast<float>(x);
141#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
142 constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
143 return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
146#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
147 constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
148 return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
150 constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
151 return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
156#define bfloat_binop(_op_, _operator_) \
158 _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
159 bfloat_binop_helper(_op_, _operator_, float, float, float); \
160 bfloat_binop_helper(_op_, _operator_, float, half, float); \
161 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
162 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
163 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
164 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
173#define bfloat_compop(__op__, __operator__) \
175 __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
176 bfloat_binop_helper(__op__, __operator__, bool, float, float); \
177 bfloat_binop_helper(__op__, __operator__, bool, half, float); \
178 bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
179 bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
180 bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
181 bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
191#undef bfloat_binop_base
192#undef bfloat_binop_helper
197#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
198 constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
199 addr_space _MLX_BFloat16& lhs, itype rhs) { \
200 lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
203 constexpr METAL_FUNC addr_space itype& __operator__( \
204 addr_space itype& lhs, _MLX_BFloat16 rhs) { \
205 lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
209#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
210 bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
211 bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
212 bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
214#define bfloat_inplace_op(itype) \
215 bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
216 bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
217 bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
218 bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
229#undef bfloat_inplace_op_helper
230#undef bfloat_inplace_op_addr_space_helper
231#undef bfloat_inplace_op
233#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
234 constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
235 addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
236 lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
240#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
241 bfloat_inplace_op_helper(__op__, __operator__, device); \
242 bfloat_inplace_op_helper(__op__, __operator__, thread); \
243 bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
250#undef bfloat_inplace_op_helper
251#undef bfloat_inplace_op_addr_space_helper
263#pragma METAL internals : enable
268struct _numeric_limits_impl<
bfloat16_t> : _fp_numeric_limits_impl_base {
269 static constexpr constant
int digits = 8;
270 static constexpr constant
int digits10 = 2;
271 static constexpr constant
int max_digits10 = 4;
272 static constexpr constant
int radix = 2;
273 static constexpr constant
int min_exponent = -125;
274 static constexpr constant
int min_exponent10 = -37;
275 static constexpr constant
int max_exponent = 128;
276 static constexpr constant
int max_exponent10 = 38;
313#pragma METAL internals : disable
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
Definition bf16.h:76
uint16_t bits_
Definition bf16.h:57
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
Definition bf16.h:67
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat()
Definition bf16.h:64
_MLX_BFloat16() thread=default
constexpr METAL_FUNC _MLX_BFloat16(T x) device
Definition bf16.h:88
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
Definition bf16.h:82
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
Definition bf16.h:94