MLX
Loading...
Searching...
No Matches
fp16.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <algorithm>
6#include <cmath>
7#include <cstdint>
8#include <vector>
9
10#define __MLX_HALF_NAN__ 0x7D00
11
12namespace mlx::core {
13
// Bit-view helper used by the _MLX_Float16 conversions: write a float into
// `f` and read its IEEE-754 binary32 bit pattern from `u`, or vice versa.
//
// Declared directly in the enclosing namespace instead of an anonymous
// namespace: an unnamed namespace in a header gives every translation unit
// its own distinct type, so the inline member functions that use this union
// would refer to different entities in different translation units (an ODR
// violation).
//
// NOTE(review): reading the inactive member of a union is not sanctioned by
// ISO C++ (std::memcpy / std::bit_cast are the portable forms); this relies
// on compiler support for union type punning.
union float_bits_fp16 {
  float f;
  uint32_t u;
};
20
22 uint16_t bits_;
23
24 // Default constructor
25 _MLX_Float16() = default;
26
27 // Default copy constructor
28 _MLX_Float16(_MLX_Float16 const&) = default;
29
30 // Appease std::vector<bool> for being special
31 _MLX_Float16& operator=(std::vector<bool>::reference x) {
32 bits_ = x;
33 return *this;
34 }
35
36 _MLX_Float16& operator=(const float& x) {
37 return (*this = _MLX_Float16(x));
38 }
39
40 // From float32
41 _MLX_Float16(const float& x) : bits_(0) {
42 // Conversion following
43 // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
44
45 // Union
46 float_bits_fp16 in;
47
48 // Take fp32 bits
49 in.f = x;
50
51 // Find and take sign bit
52 uint32_t x_sign_32 = in.u & uint32_t(0x80000000);
53 uint16_t x_sign_16 = (x_sign_32 >> 16);
54
55 if (std::isnan(x)) {
56 bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__);
57 } else {
58 // Union
59 float_bits_fp16 inf_scale, zero_scale, magic_bits;
60
61 // Find exponent bits and take the max supported by half
62 uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);
63 uint32_t max_expo_32 = uint32_t(0x38800000);
64 x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32;
65 x_expo_32 += uint32_t(15) << 23;
66
67 // Handle scaling to inf as needed
68 inf_scale.u = uint32_t(0x77800000);
69 zero_scale.u = uint32_t(0x08800000);
70
71 // Combine with magic and let addition do rounding
72 magic_bits.u = x_expo_32;
73 magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;
74
75 // Take the lower 5 bits of the exponent
76 uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));
77
78 // Collect the lower 12 bits which have the mantissa
79 uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);
80
81 // Combine sign, exp and mantissa
82 bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));
83 }
84 }
85
86 // To float32
87 operator float() const {
88 // Conversion following
89 // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
90
91 // Union
92 float_bits_fp16 out;
93
94 uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000);
95 uint32_t base = (bits_ << 16);
96 uint32_t two_base = base + base;
97
98 uint32_t denorm_max = 1u << 27;
99 if (two_base < denorm_max) {
100 out.u = uint32_t(126) << 23; // magic mask
101 out.u |= (two_base >> 17); // Bits from fp16
102 out.f -= 0.5f; // magic bias
103 } else {
104 out.u = uint32_t(0xE0) << 23; // exponent offset
105 out.u += (two_base >> 4); // Bits from fp16
106 float out_unscaled = out.f; // Store value
107 out.u = uint32_t(0x7800000); // exponent scale
108 out.f *= out_unscaled;
109 }
110
111 // Add sign
112 out.u |= x_sign_32;
113
114 return out.f;
115 }
116};
117
// Operator-generating macros shared by the operator groups below.
// Parameters:
//   OP      operator token to apply (e.g. +, <)
//   FN      name of the function to define (e.g. operator+)
//   RET     return type of the generated function
//   LHS_T, RHS_T, T  operand types
//   CALC_T  type both operands are static_cast to before OP is applied
// Parameter names avoid double underscores: identifiers containing "__" are
// reserved to the implementation in C++.

// Defines `inline RET FN(LHS_T, RHS_T)`.
#define half_binop_base(OP, FN, RET, LHS_T, RHS_T, CALC_T)   \
  inline RET FN(LHS_T lhs, RHS_T rhs) {                       \
    return static_cast<CALC_T>(lhs) OP static_cast<CALC_T>(rhs); \
  }

// Defines both `inline RET FN(_MLX_Float16, T)` and
// `inline RET FN(T, _MLX_Float16)`.
#define half_binop_helper(OP, FN, RET, T, CALC_T)               \
  inline RET FN(_MLX_Float16 lhs, T rhs) {                      \
    return static_cast<CALC_T>(lhs) OP static_cast<CALC_T>(rhs); \
  }                                                             \
  inline RET FN(T lhs, _MLX_Float16 rhs) {                      \
    return static_cast<CALC_T>(lhs) OP static_cast<CALC_T>(rhs); \
  }
130
131// Operators
132#define half_binop(__op__, __operator__) \
133 half_binop_base( \
134 __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \
135 half_binop_helper(__op__, __operator__, float, float, float); \
136 half_binop_helper(__op__, __operator__, double, double, double); \
137 half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \
138 half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \
139 half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \
140 half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \
141 half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);
142
143half_binop(+, operator+);
144half_binop(-, operator-);
145half_binop(*, operator*);
146half_binop(/, operator/);
147
148#undef half_binop
149
150// Comparison ops
151#define half_compop(__op__, __operator__) \
152 half_binop_base( \
153 __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \
154 half_binop_helper(__op__, __operator__, bool, float, float); \
155 half_binop_helper(__op__, __operator__, bool, double, double); \
156 half_binop_helper(__op__, __operator__, bool, int32_t, float); \
157 half_binop_helper(__op__, __operator__, bool, uint32_t, float); \
158 half_binop_helper(__op__, __operator__, bool, int64_t, float); \
159 half_binop_helper(__op__, __operator__, bool, uint64_t, float);
160
161half_compop(>, operator>);
162half_compop(<, operator<);
163half_compop(>=, operator>=);
164half_compop(<=, operator<=);
165half_compop(==, operator==);
166half_compop(!=, operator!=);
167
168#undef half_compop
169
170// Negative
172 return -static_cast<float>(lhs);
173}
174
175// Inplace ops
176#define half_inplace_op(__op__, __operator__) \
177 inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \
178 lhs = lhs __op__ rhs; \
179 return lhs; \
180 } \
181 inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \
182 lhs = lhs __op__ rhs; \
183 return lhs; \
184 }
185
186half_inplace_op(+, operator+=);
187half_inplace_op(-, operator-=);
188half_inplace_op(*, operator*=);
189half_inplace_op(/, operator/=);
190
191#undef half_inplace_op
192
193// Bitwise ops
194
195#define half_bitop(__op__, __operator__) \
196 inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \
197 _MLX_Float16 out; \
198 out.bits_ = lhs.bits_ __op__ rhs.bits_; \
199 return out; \
200 } \
201 inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \
202 _MLX_Float16 out; \
203 out.bits_ = lhs.bits_ __op__ rhs; \
204 return out; \
205 } \
206 inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \
207 _MLX_Float16 out; \
208 out.bits_ = lhs __op__ rhs.bits_; \
209 return out; \
210 }
211
212half_bitop(|, operator|);
213half_bitop(&, operator&);
214half_bitop(^, operator^);
215
216#undef half_bitop
217
218#define half_inplace_bitop(__op__, __operator__) \
219 inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \
220 lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \
221 return lhs; \
222 } \
223 inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \
224 lhs.bits_ = lhs.bits_ __op__ rhs; \
225 return lhs; \
226 }
227
228half_inplace_bitop(|, operator|=);
229half_inplace_bitop(&, operator&=);
230half_inplace_bitop(^, operator^=);
231
232#undef half_inplace_bitop
233
234} // namespace mlx::core
#define __MLX_HALF_NAN__
Definition fp16.h:10
#define half_bitop(__op__, __operator__)
Definition fp16.h:195
#define half_inplace_bitop(__op__, __operator__)
Definition fp16.h:218
#define half_inplace_op(__op__, __operator__)
Definition fp16.h:176
#define half_compop(__op__, __operator__)
Definition fp16.h:151
#define half_binop(__op__, __operator__)
Definition fp16.h:132
array operator-(const array &a)
Definition allocator.h:7
Definition fp16.h:21
_MLX_Float16(_MLX_Float16 const &)=default
_MLX_Float16 & operator=(const float &x)
Definition fp16.h:36
uint16_t bits_
Definition fp16.h:22
_MLX_Float16 & operator=(std::vector< bool >::reference x)
Definition fp16.h:31
_MLX_Float16(const float &x)
Definition fp16.h:41
uint32_t u
Definition fp16.h:17
float f
Definition fp16.h:16