mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Affine quant always in fp32 (#1925)
* do affine quant in fp32 * static cast
This commit is contained in:
parent
3835a428c5
commit
fd0d63ba5b
@ -543,8 +543,8 @@ void quantize(
|
|||||||
T* scales = scales_.data<T>();
|
T* scales = scales_.data<T>();
|
||||||
T* biases = biases_.data<T>();
|
T* biases = biases_.data<T>();
|
||||||
|
|
||||||
T n_bins = (1 << bits) - 1;
|
float n_bins = (1 << bits) - 1;
|
||||||
T eps = 1e-7;
|
float eps = 1e-7;
|
||||||
bool power_of_2_bits = is_power_of_2(bits);
|
bool power_of_2_bits = is_power_of_2(bits);
|
||||||
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||||
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
|
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
|
||||||
@ -554,32 +554,30 @@ void quantize(
|
|||||||
|
|
||||||
for (size_t i = 0; i < n_groups; ++i) {
|
for (size_t i = 0; i < n_groups; ++i) {
|
||||||
size_t w_idx = i * group_size;
|
size_t w_idx = i * group_size;
|
||||||
T w_min = std::numeric_limits<float>::infinity();
|
float w_min = std::numeric_limits<float>::infinity();
|
||||||
T w_max = -w_min;
|
float w_max = -w_min;
|
||||||
for (int j = 0; j < group_size; ++j) {
|
for (int j = 0; j < group_size; ++j) {
|
||||||
w_max = std::max(w_max, w[w_idx + j]);
|
w_max = std::max(w_max, (float)w[w_idx + j]);
|
||||||
w_min = std::min(w_min, w[w_idx + j]);
|
w_min = std::min(w_min, (float)w[w_idx + j]);
|
||||||
}
|
}
|
||||||
bool mask = std::abs(w_min) > std::abs(w_max);
|
bool mask = std::abs(w_min) > std::abs(w_max);
|
||||||
T scale = std::max(T((w_max - w_min) / n_bins), eps);
|
float scale = std::max((w_max - w_min) / n_bins, eps);
|
||||||
scale = mask ? scale : -scale;
|
scale = mask ? scale : -scale;
|
||||||
|
|
||||||
auto edge = mask ? w_min : w_max;
|
float edge = mask ? w_min : w_max;
|
||||||
auto q0 = std::rint(edge / scale);
|
float q0 = std::rint(edge / scale);
|
||||||
if (q0 == 0) {
|
float bias = 0;
|
||||||
scales[i] = scale;
|
if (q0 != 0) {
|
||||||
biases[i] = 0;
|
scale = edge / q0;
|
||||||
} else {
|
bias = edge;
|
||||||
scales[i] = edge / q0;
|
|
||||||
biases[i] = edge;
|
|
||||||
}
|
}
|
||||||
size_t out_idx = i * int_per_group;
|
size_t out_idx = i * int_per_group;
|
||||||
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
|
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
|
||||||
uint32_t out_el = 0;
|
uint32_t out_el = 0;
|
||||||
for (int k = 0; k < el_per_int; ++k) {
|
for (int k = 0; k < el_per_int; ++k) {
|
||||||
T w_el = w[w_idx + j * el_per_int + k];
|
float w_el = w[w_idx + j * el_per_int + k];
|
||||||
w_el = std::rint((w_el - biases[i]) / scales[i]);
|
w_el = std::rint((w_el - bias) / scale);
|
||||||
w_el = std::min(std::max(w_el, T(0)), n_bins);
|
w_el = std::min(std::max(w_el, 0.0f), n_bins);
|
||||||
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
|
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
|
||||||
}
|
}
|
||||||
if (power_of_2_bits) {
|
if (power_of_2_bits) {
|
||||||
@ -590,6 +588,8 @@ void quantize(
|
|||||||
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
|
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
scales[i] = static_cast<T>(scale);
|
||||||
|
biases[i] = static_cast<T>(bias);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2015,9 +2015,9 @@ template <typename T, const int group_size, const int bits>
|
|||||||
device T* biases [[buffer(3)]],
|
device T* biases [[buffer(3)]],
|
||||||
uint2 index [[thread_position_in_grid]],
|
uint2 index [[thread_position_in_grid]],
|
||||||
uint2 grid_dim [[threads_per_grid]]) {
|
uint2 grid_dim [[threads_per_grid]]) {
|
||||||
constexpr T eps = T(1e-7);
|
constexpr float eps = 1e-7;
|
||||||
constexpr int simd_size = 32;
|
constexpr int simd_size = 32;
|
||||||
constexpr T n_bins = (1 << bits) - 1;
|
constexpr float n_bins = (1 << bits) - 1;
|
||||||
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||||
constexpr int values_per_reduce = group_size / simd_size;
|
constexpr int values_per_reduce = group_size / simd_size;
|
||||||
constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
|
constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
|
||||||
@ -2036,13 +2036,13 @@ template <typename T, const int group_size, const int bits>
|
|||||||
? offset * writes_per_pack
|
? offset * writes_per_pack
|
||||||
: offset * bytes_per_pack / writes_per_reduce;
|
: offset * bytes_per_pack / writes_per_reduce;
|
||||||
|
|
||||||
T w_thread[values_per_reduce];
|
float w_thread[values_per_reduce];
|
||||||
T w_min = Limits<T>::max;
|
float w_min = Limits<T>::max;
|
||||||
T w_max = 0;
|
float w_max = 0;
|
||||||
|
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int i = 0; i < values_per_reduce; i++) {
|
for (int i = 0; i < values_per_reduce; i++) {
|
||||||
T val = w[in_index + i];
|
float val = w[in_index + i];
|
||||||
w_thread[i] = val;
|
w_thread[i] = val;
|
||||||
w_min = min(w_min, val);
|
w_min = min(w_min, val);
|
||||||
w_max = max(w_max, val);
|
w_max = max(w_max, val);
|
||||||
@ -2051,20 +2051,20 @@ template <typename T, const int group_size, const int bits>
|
|||||||
w_min = simd_min(w_min);
|
w_min = simd_min(w_min);
|
||||||
w_max = simd_max(w_max);
|
w_max = simd_max(w_max);
|
||||||
|
|
||||||
T scale = max((w_max - w_min) / n_bins, eps);
|
float scale = max((w_max - w_min) / n_bins, eps);
|
||||||
bool side = abs(w_min) > abs(w_max);
|
bool side = abs(w_min) > abs(w_max);
|
||||||
scale = side ? scale : -scale;
|
scale = side ? scale : -scale;
|
||||||
T edge = side ? w_min : w_max;
|
float edge = side ? w_min : w_max;
|
||||||
T q0 = round(edge / scale);
|
float q0 = round(edge / scale);
|
||||||
bool at_zero = q0 == 0.0f;
|
bool at_zero = q0 == 0.0f;
|
||||||
scale = at_zero ? scale : edge / q0;
|
scale = at_zero ? scale : edge / q0;
|
||||||
T bias = at_zero ? T(0) : edge;
|
float bias = at_zero ? 0 : edge;
|
||||||
|
|
||||||
// Write out the scales and biases
|
// Write out the scales and biases
|
||||||
size_t gindex = in_index / group_size;
|
size_t gindex = in_index / group_size;
|
||||||
if (in_index % group_size == 0) {
|
if (in_index % group_size == 0) {
|
||||||
scales[gindex] = scale;
|
scales[gindex] = static_cast<T>(scale);
|
||||||
biases[gindex] = bias;
|
biases[gindex] = static_cast<T>(bias);
|
||||||
}
|
}
|
||||||
|
|
||||||
// We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
|
// We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
|
||||||
|
12
mlx/fast.cpp
12
mlx/fast.cpp
@ -827,14 +827,17 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
|
|||||||
auto wshape = w.shape();
|
auto wshape = w.shape();
|
||||||
wshape.back() = -1;
|
wshape.back() = -1;
|
||||||
|
|
||||||
array zero(0, w.dtype());
|
array zero(0, float32);
|
||||||
array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
|
array n_bins((1 << bits) - 1, float32); // 2**bits - 1
|
||||||
array eps(1e-7, w.dtype());
|
array eps(1e-7, float32);
|
||||||
|
|
||||||
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
|
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
|
||||||
|
|
||||||
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||||
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||||
|
w_max = astype(w_max, float32, s);
|
||||||
|
w_min = astype(w_min, float32, s);
|
||||||
|
|
||||||
array mask = greater(abs(w_min, s), abs(w_max, s), s);
|
array mask = greater(abs(w_min, s), abs(w_max, s), s);
|
||||||
array scales =
|
array scales =
|
||||||
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
|
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
|
||||||
@ -845,6 +848,9 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
|
|||||||
array biases = where(equal(q0, zero, s), zero, edge, s);
|
array biases = where(equal(q0, zero, s), zero, edge, s);
|
||||||
|
|
||||||
packed_w = pack_and_quantize(packed_w, scales, biases, bits, s);
|
packed_w = pack_and_quantize(packed_w, scales, biases, bits, s);
|
||||||
|
|
||||||
|
scales = astype(scales, w.dtype(), s);
|
||||||
|
biases = astype(biases, w.dtype(), s);
|
||||||
return {
|
return {
|
||||||
reshape(packed_w, wshape, s),
|
reshape(packed_w, wshape, s),
|
||||||
reshape(scales, wshape, s),
|
reshape(scales, wshape, s),
|
||||||
|
Loading…
Reference in New Issue
Block a user