mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Improvements in the quantizer and dequantization kernel (#1061)
This commit is contained in:
committed by
GitHub
parent
7f7b9662ea
commit
17f57df797
24
mlx/ops.cpp
24
mlx/ops.cpp
@@ -3275,7 +3275,9 @@ std::tuple<array, array, array> quantize(
|
||||
}
|
||||
|
||||
// Compute some constants used for the quantization
|
||||
int n_bins = (1 << bits) - 1; // 2**bits - 1
|
||||
array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
|
||||
array eps(1e-7, w.dtype());
|
||||
array zero(0, w.dtype());
|
||||
int el_per_int = 32 / bits;
|
||||
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
|
||||
shifts = reshape(shifts, {1, 1, -1}, s);
|
||||
@@ -3299,16 +3301,22 @@ std::tuple<array, array, array> quantize(
|
||||
reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
|
||||
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
array scales = maximum(
|
||||
divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s),
|
||||
array(1e-7, w.dtype()),
|
||||
s);
|
||||
// making sure that 0 is represented exactly in the resulting quantization
|
||||
array biases = multiply(round(divide(w_min, scales, s), s), scales, s);
|
||||
|
||||
array mask = greater(abs(w_min, s), abs(w_max, s), s);
|
||||
array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
|
||||
scales = where(mask, scales, negative(scales), s);
|
||||
array edge = where(mask, w_min, w_max, s);
|
||||
array q0 = round(divide(edge, scales, s), s);
|
||||
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
|
||||
array biases = where(equal(q0, zero, s), zero, edge);
|
||||
|
||||
// Quantize and pack w
|
||||
packed_w = astype(
|
||||
round(divide(subtract(packed_w, biases, s), scales, s), s), uint32);
|
||||
clip(
|
||||
round(divide(subtract(packed_w, biases, s), scales, s), s),
|
||||
zero,
|
||||
n_bins),
|
||||
uint32);
|
||||
packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s);
|
||||
packed_w = sum(
|
||||
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
|
||||
|
||||
Reference in New Issue
Block a user