MLX
Loading...
Searching...
No Matches
bf16.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <metal_stdlib>
6
7using namespace metal;
8
9#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
10
11typedef bfloat bfloat16_t;
12
13#else
14
16// Helpers
18
19constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
20 // Check for nan
21 if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
22 _fp_encoding_traits<float>::inf_mask) {
23 return uint16_t(as_type<uint32_t>(0x7FC0));
24 }
25 // Take bits
26 uint32_t float_bits = as_type<uint32_t>(x);
27
28 // Round to nearest even
29 float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
30
31 // Take upper 16 bits
32 return float_bits >> 16;
33}
34
35constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
36 // Upper 16 bits are the data and lower 16 bits are 0s
37 return as_type<float>((uint32_t)x << 16);
38}
39
40struct _MLX_BFloat16;
41
// True when T is a type other than _MLX_BFloat16 that is convertible to float.
// Used as the enable_if condition on the templated _MLX_BFloat16(T)
// constructors.
// NOTE(review): excluding _MLX_BFloat16 itself presumably keeps those templates
// from competing with the copy constructor -- confirm.
42template <typename T>
43static constexpr constant bool can_convert_to_bfloat =
44 !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
45
// True when T is a type other than _MLX_BFloat16 that float is convertible to.
// Used as the enable_if condition on the templated `operator T()` conversions
// of _MLX_BFloat16.
46template <typename T>
47static constexpr constant bool can_convert_from_bfloat =
48 !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
49
51// Bfloat struct
53
56 // Constructors
// 16-bit encoding of the value; the constructors below write it and the
// conversion operators read it.
57 uint16_t bits_;
// Defaulted default constructors, declared once for each address-space
// qualifier (thread, threadgroup, device, constant).
58 _MLX_BFloat16() thread = default;
59 _MLX_BFloat16() threadgroup = default;
60 _MLX_BFloat16() device = default;
61 _MLX_BFloat16() constant = default;
62
64 static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
65 return bits_to_bfloat_struct();
66 }
67 constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
68 : bits_(bits) {}
69
71 // Conversions to bfloat
72
// Implicit converting constructors, one per address-space qualifier
// (thread, threadgroup, device, constant). Each is enabled for any T with
// can_convert_to_bfloat<T>; the argument is cast to float and the result of
// float_to_bfloat_bits is stored in bits_.
73 template <
74 typename T,
75 typename = typename enable_if<can_convert_to_bfloat<T>>::type>
76 constexpr METAL_FUNC _MLX_BFloat16(T x) thread
77 : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
78
79 template <
80 typename T,
81 typename = typename enable_if<can_convert_to_bfloat<T>>::type>
82 constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
83 : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
84
85 template <
86 typename T,
87 typename = typename enable_if<can_convert_to_bfloat<T>>::type>
88 constexpr METAL_FUNC _MLX_BFloat16(T x) device
89 : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
90
91 template <
92 typename T,
93 typename = typename enable_if<can_convert_to_bfloat<T>>::type>
94 constexpr METAL_FUNC _MLX_BFloat16(T x) constant
95 : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
96
98 // Conversions from bfloat
99
// Implicit conversion operators, one per address-space qualifier
// (thread, threadgroup, device, constant). Each is enabled for any T with
// can_convert_from_bfloat<T>; bits_ is expanded to a float by
// bfloat_bits_to_float and that float is then static_cast to T.
100 template <
101 typename T,
102 typename = typename enable_if<can_convert_from_bfloat<T>>::type>
103 constexpr METAL_FUNC operator T() const thread {
104 return static_cast<T>(bfloat_bits_to_float(bits_));
105 }
106
107 template <
108 typename T,
109 typename = typename enable_if<can_convert_from_bfloat<T>>::type>
110 constexpr METAL_FUNC operator T() const threadgroup {
111 return static_cast<T>(bfloat_bits_to_float(bits_));
112 }
113
114 template <
115 typename T,
116 typename = typename enable_if<can_convert_from_bfloat<T>>::type>
117 constexpr METAL_FUNC operator T() const device {
118 return static_cast<T>(bfloat_bits_to_float(bits_));
119 }
120
121 template <
122 typename T,
123 typename = typename enable_if<can_convert_from_bfloat<T>>::type>
124 constexpr METAL_FUNC operator T() const constant {
125 return static_cast<T>(bfloat_bits_to_float(bits_));
126 }
127};
128
130// Bfloat operators
132
134// Unary ops
135constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
136 return -static_cast<float>(x);
137}
138
140// Binary operators
// Emits `otype op_name(atype, btype)`: both operands are cast to `ctype`,
// combined with `op_token`, and the result is returned as `otype`.
#define bfloat_binop_base(op_token, op_name, otype, atype, btype, ctype) \
  constexpr METAL_FUNC otype op_name(atype a, btype b) { \
    return static_cast<ctype>(a) op_token static_cast<ctype>(b); \
  }
145
// Emits the two mixed-type overloads of `op_name` between _MLX_BFloat16 and
// `itype`, one per operand order. Each is generated through bfloat_binop_base,
// so both operands are cast to `ctype` before `op_token` is applied.
#define bfloat_binop_helper(op_token, op_name, otype, itype, ctype) \
  bfloat_binop_base(op_token, op_name, otype, _MLX_BFloat16, itype, ctype) \
  bfloat_binop_base(op_token, op_name, otype, itype, _MLX_BFloat16, ctype)
153
155// Arithmetic Operators
156#define bfloat_binop(_op_, _operator_) \
157 bfloat_binop_base( \
158 _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
159 bfloat_binop_helper(_op_, _operator_, float, float, float); \
160 bfloat_binop_helper(_op_, _operator_, float, half, float); \
161 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
162 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
163 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
164 bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
165
166bfloat_binop(+, operator+);
167bfloat_binop(-, operator-);
168bfloat_binop(*, operator*);
169bfloat_binop(/, operator/);
170
172// Comparison ops
173#define bfloat_compop(__op__, __operator__) \
174 bfloat_binop_base( \
175 __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
176 bfloat_binop_helper(__op__, __operator__, bool, float, float); \
177 bfloat_binop_helper(__op__, __operator__, bool, half, float); \
178 bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
179 bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
180 bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
181 bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
182
183bfloat_compop(>, operator>);
184bfloat_compop(<, operator<);
185bfloat_compop(>=, operator>=);
186bfloat_compop(<=, operator<=);
187bfloat_compop(==, operator==);
188bfloat_compop(!=, operator!=);
189
190#undef bfloat_compop
191#undef bfloat_binop_base
192#undef bfloat_binop_helper
193#undef bfloat_binop
194
196// Inplace Operators
// Emits two compound-assignment overloads of `op_name` for references in
// `addr_space`: one updating a _MLX_BFloat16 from an `itype` operand, and one
// updating an `itype` from a _MLX_BFloat16 operand. Both compute in float,
// assign the result back to the left-hand side, and return that reference.
#define bfloat_inplace_op_helper(op_token, op_name, itype, addr_space) \
  constexpr METAL_FUNC addr_space _MLX_BFloat16& op_name( \
      addr_space _MLX_BFloat16& target, itype operand) { \
    target = static_cast<float>(target) op_token static_cast<float>(operand); \
    return target; \
  } \
  constexpr METAL_FUNC addr_space itype& op_name( \
      addr_space itype& target, _MLX_BFloat16 operand) { \
    target = static_cast<float>(target) op_token static_cast<float>(operand); \
    return target; \
  }
208
// Repeats bfloat_inplace_op_helper for the device, thread and threadgroup
// address spaces.
#define bfloat_inplace_op_addr_space_helper(op_token, op_name, itype) \
  bfloat_inplace_op_helper(op_token, op_name, itype, device); \
  bfloat_inplace_op_helper(op_token, op_name, itype, thread); \
  bfloat_inplace_op_helper(op_token, op_name, itype, threadgroup);
213
// Emits +=, -=, *= and /= between _MLX_BFloat16 and `operand_type`.
#define bfloat_inplace_op(operand_type) \
  bfloat_inplace_op_addr_space_helper(+, operator+=, operand_type); \
  bfloat_inplace_op_addr_space_helper(-, operator-=, operand_type); \
  bfloat_inplace_op_addr_space_helper(*, operator*=, operand_type); \
  bfloat_inplace_op_addr_space_helper(/, operator/=, operand_type);
219
228
229#undef bfloat_inplace_op_helper
230#undef bfloat_inplace_op_addr_space_helper
231#undef bfloat_inplace_op
232
// bfloat-with-bfloat form of the compound-assignment helper: emits `op_name`
// taking an `addr_space` _MLX_BFloat16 reference and a _MLX_BFloat16 value.
// The operation is done in float, stored back, and the reference is returned.
#define bfloat_inplace_op_helper(op_token, op_name, addr_space) \
  constexpr METAL_FUNC addr_space _MLX_BFloat16& op_name( \
      addr_space _MLX_BFloat16& target, _MLX_BFloat16 operand) { \
    target = static_cast<float>(target) op_token static_cast<float>(operand); \
    return target; \
  }
239
// Repeats the bfloat-with-bfloat bfloat_inplace_op_helper for the device,
// thread and threadgroup address spaces.
#define bfloat_inplace_op_addr_space_helper(op_token, op_name) \
  bfloat_inplace_op_helper(op_token, op_name, device); \
  bfloat_inplace_op_helper(op_token, op_name, thread); \
  bfloat_inplace_op_helper(op_token, op_name, threadgroup);
244
249
250#undef bfloat_inplace_op_helper
251#undef bfloat_inplace_op_addr_space_helper
252
254// Bfloat typedef
256
258
260// Bfloat numeric limits
262
263#pragma METAL internals : enable
264
265namespace metal {
266
// Specialization of _numeric_limits_impl for bfloat16_t, derived from
// _fp_numeric_limits_impl_base.
267template <>
268struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
// Precision constants: radix-2 significand digits and decimal digit counts.
269 static constexpr constant int digits = 8;
270 static constexpr constant int digits10 = 2;
271 static constexpr constant int max_digits10 = 4;
272 static constexpr constant int radix = 2;
// Exponent range constants.
// NOTE(review): these equal the binary32 float values, presumably because
// bfloat16 is the upper half of a float bit pattern -- confirm.
273 static constexpr constant int min_exponent = -125;
274 static constexpr constant int min_exponent10 = -37;
275 static constexpr constant int max_exponent = 128;
276 static constexpr constant int max_exponent10 = 38;
277
// Characteristic-value accessors, each returning a bfloat16_t.
278 static constexpr bfloat16_t min() {
280 }
281 static constexpr bfloat16_t lowest() {
283 }
284 static constexpr bfloat16_t max() {
286 }
287 static constexpr bfloat16_t epsilon() {
289 }
290 static constexpr bfloat16_t round_error() {
292 }
293 static constexpr bfloat16_t infinity() {
295 }
296 static constexpr bfloat16_t quiet_NaN() {
298 }
299 static constexpr bfloat16_t signaling_NaN() {
301 }
302 static constexpr bfloat16_t denorm_min() {
304 }
305};
306
307METAL_FUNC bool isnan(_MLX_BFloat16 x) {
308 return x != x;
309}
310
311} // namespace metal
312
313#pragma METAL internals : disable
314
315#endif
316
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x)
Definition bf16.h:19
#define bfloat_compop(__op__, __operator__)
Definition bf16.h:173
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x)
Definition bf16.h:35
#define bfloat_inplace_op(itype)
Definition bf16.h:214
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x)
Definition bf16.h:135
#define bfloat_binop(_op_, _operator_)
Definition bf16.h:156
struct _MLX_BFloat16 bfloat16_t
Definition bf16.h:257
static constexpr constant bool can_convert_from_bfloat
Definition bf16.h:47
static constexpr constant bool can_convert_to_bfloat
Definition bf16.h:43
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype)
Definition bf16.h:209
Definition bf16.h:265
METAL_FUNC bool isnan(_MLX_BFloat16 x)
Definition bf16.h:307
Definition bf16.h:54
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
Definition bf16.h:76
uint16_t bits_
Definition bf16.h:57
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
Definition bf16.h:67
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat()
Definition bf16.h:64
_MLX_BFloat16() thread=default
constexpr METAL_FUNC _MLX_BFloat16(T x) device
Definition bf16.h:88
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
Definition bf16.h:82
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
Definition bf16.h:94
static constexpr bfloat16_t infinity()
Definition bf16.h:293
static constexpr bfloat16_t denorm_min()
Definition bf16.h:302
static constexpr bfloat16_t max()
Definition bf16.h:284
static constexpr bfloat16_t epsilon()
Definition bf16.h:287
static constexpr bfloat16_t signaling_NaN()
Definition bf16.h:299
static constexpr bfloat16_t min()
Definition bf16.h:278
static constexpr bfloat16_t lowest()
Definition bf16.h:281
static constexpr bfloat16_t quiet_NaN()
Definition bf16.h:296
static constexpr bfloat16_t round_error()
Definition bf16.h:290