mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Affine quant always in fp32 (#1925)
* do affine quant in fp32
* static cast
This commit is contained in:
@@ -2015,9 +2015,9 @@ template <typename T, const int group_size, const int bits>
|
||||
device T* biases [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
constexpr T eps = T(1e-7);
|
||||
constexpr float eps = 1e-7;
|
||||
constexpr int simd_size = 32;
|
||||
constexpr T n_bins = (1 << bits) - 1;
|
||||
constexpr float n_bins = (1 << bits) - 1;
|
||||
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||
constexpr int values_per_reduce = group_size / simd_size;
|
||||
constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
|
||||
@@ -2036,13 +2036,13 @@ template <typename T, const int group_size, const int bits>
|
||||
? offset * writes_per_pack
|
||||
: offset * bytes_per_pack / writes_per_reduce;
|
||||
|
||||
T w_thread[values_per_reduce];
|
||||
T w_min = Limits<T>::max;
|
||||
T w_max = 0;
|
||||
float w_thread[values_per_reduce];
|
||||
float w_min = Limits<T>::max;
|
||||
float w_max = 0;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < values_per_reduce; i++) {
|
||||
T val = w[in_index + i];
|
||||
float val = w[in_index + i];
|
||||
w_thread[i] = val;
|
||||
w_min = min(w_min, val);
|
||||
w_max = max(w_max, val);
|
||||
@@ -2051,20 +2051,20 @@ template <typename T, const int group_size, const int bits>
|
||||
w_min = simd_min(w_min);
|
||||
w_max = simd_max(w_max);
|
||||
|
||||
T scale = max((w_max - w_min) / n_bins, eps);
|
||||
float scale = max((w_max - w_min) / n_bins, eps);
|
||||
bool side = abs(w_min) > abs(w_max);
|
||||
scale = side ? scale : -scale;
|
||||
T edge = side ? w_min : w_max;
|
||||
T q0 = round(edge / scale);
|
||||
float edge = side ? w_min : w_max;
|
||||
float q0 = round(edge / scale);
|
||||
bool at_zero = q0 == 0.0f;
|
||||
scale = at_zero ? scale : edge / q0;
|
||||
T bias = at_zero ? T(0) : edge;
|
||||
float bias = at_zero ? 0 : edge;
|
||||
|
||||
// Write out the scales and biases
|
||||
size_t gindex = in_index / group_size;
|
||||
if (in_index % group_size == 0) {
|
||||
scales[gindex] = scale;
|
||||
biases[gindex] = bias;
|
||||
scales[gindex] = static_cast<T>(scale);
|
||||
biases[gindex] = static_cast<T>(bias);
|
||||
}
|
||||
|
||||
// We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
|
||||
|
Reference in New Issue
Block a user