mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Affine quant always in fp32 (#1925)
* do affine quant in fp32 * static cast
This commit is contained in:
12
mlx/fast.cpp
12
mlx/fast.cpp
@@ -827,14 +827,17 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
|
||||
auto wshape = w.shape();
|
||||
wshape.back() = -1;
|
||||
|
||||
array zero(0, w.dtype());
|
||||
array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
|
||||
array eps(1e-7, w.dtype());
|
||||
array zero(0, float32);
|
||||
array n_bins((1 << bits) - 1, float32); // 2**bits - 1
|
||||
array eps(1e-7, float32);
|
||||
|
||||
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
|
||||
|
||||
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
w_max = astype(w_max, float32, s);
|
||||
w_min = astype(w_min, float32, s);
|
||||
|
||||
array mask = greater(abs(w_min, s), abs(w_max, s), s);
|
||||
array scales =
|
||||
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
|
||||
@@ -845,6 +848,9 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
|
||||
array biases = where(equal(q0, zero, s), zero, edge, s);
|
||||
|
||||
packed_w = pack_and_quantize(packed_w, scales, biases, bits, s);
|
||||
|
||||
scales = astype(scales, w.dtype(), s);
|
||||
biases = astype(biases, w.dtype(), s);
|
||||
return {
|
||||
reshape(packed_w, wshape, s),
|
||||
reshape(scales, wshape, s),
|
||||
|
||||
Reference in New Issue
Block a user