Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)

Commit 17f57df797: Improvements in the quantizer and dequantization kernel (#1061)
Parent: 7f7b9662ea
@@ -205,13 +205,10 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
   }
   else if (bits == 4) {
-    U s[2] = {scale, scale / 16.0f};
-    for (int i = 0; i < (values_per_thread / 2); i++) {
-      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
-      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
+    const thread uint16_t* ws = (const thread uint16_t*)w;
+    U s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
+    for (int i = 0; i < (values_per_thread / 4); i++) {
+      result[4 * i] += x * (s[0] * (ws[i] & 0x000f) + bias);
+      result[4 * i + 1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
+      result[4 * i + 2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
+      result[4 * i + 3] += x * (s[3] * (ws[i] & 0xf000) + bias);
     }
   }
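The 4-bit path above now reads the packed weights 16 bits at a time and masks each nibble in place instead of shifting it down: a field at bit offset k contributes value * 2^k, so the shift is folded into a pre-divided scale (scale, scale/16, scale/256, scale/4096). A minimal NumPy sketch of that identity, with illustrative names standing in for the Metal code:

```python
import numpy as np

# Illustrative check (not the Metal source): masking a 4-bit field in place
# at bit offset k and multiplying by scale / 2**k equals shifting it down
# to bit 0 and multiplying by scale.
rng = np.random.default_rng(0)
packed = rng.integers(0, 2**16, size=8, dtype=np.uint16)
scale, bias = np.float32(0.05), np.float32(-0.4)

# Reference: extract each nibble with an explicit shift.
shifted = [((packed >> k) & 0xF).astype(np.float32) * scale + bias
           for k in (0, 4, 8, 12)]

# Kernel-style: mask in place, fold the shift into pre-divided scales.
s = [scale / np.float32(16.0) ** j for j in range(4)]
masked = [(packed & m).astype(np.float32) * s[j] + bias
          for j, m in enumerate((0x000F, 0x00F0, 0x0F00, 0xF000))]

assert all(np.allclose(a, b) for a, b in zip(shifted, masked))
```

Dividing by a power of two is exact in floating point, so the two forms agree bit for bit; the win in the kernel is replacing per-element shifts with wider loads and plain multiplies.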
@@ -244,17 +241,10 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
   }
   else if (bits == 4) {
-    U s[2] = {scale, scale / static_cast<U>(16.0f)};
-    for (int i = 0; i < (N / 2); i++) {
-      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
-      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
+    const device uint16_t* ws = (const device uint16_t*)w;
+    U s[4] = {
+        scale,
+        scale / static_cast<U>(16.0f),
+        scale / static_cast<U>(256.0f),
+        scale / static_cast<U>(4096.0f)};
+    for (int i = 0; i < (N / 4); i++) {
+      w_local[4 * i] = s[0] * (ws[i] & 0x000f) + bias;
+      w_local[4 * i + 1] = s[1] * (ws[i] & 0x00f0) + bias;
+      w_local[4 * i + 2] = s[2] * (ws[i] & 0x0f00) + bias;
+      w_local[4 * i + 3] = s[3] * (ws[i] & 0xf000) + bias;
     }
   }
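The threadgroup `dequantize` above applies the same trick. One subtlety worth checking is element order: the `uint16_t` read only reproduces the old byte-at-a-time order on a little-endian layout, where byte 0 occupies bits 0 through 7 of the wider word. A hypothetical NumPy check of that ordering, not part of the commit:

```python
import numpy as np

# Hypothetical ordering check: on a little-endian host (as the kernel
# assumes), viewing the packed buffer as uint16 yields the same element
# order as the old byte-at-a-time loop.
w = np.arange(16, dtype=np.uint8)    # packed 4-bit weights, two per byte
ws = w.view(np.uint16)               # four weights per uint16
scale, bias = np.float32(0.1), np.float32(0.0)

# Old path: two nibbles per byte.
old = np.empty(2 * w.size, dtype=np.float32)
old[0::2] = (w & 0x0F) * scale + bias
old[1::2] = (w & 0xF0) * (scale / 16) + bias

# New path: four nibbles per uint16, shifts folded into the scales.
s = [scale / np.float32(16.0) ** j for j in range(4)]
new = np.empty(2 * w.size, dtype=np.float32)
for j, m in enumerate((0x000F, 0x00F0, 0x0F00, 0xF000)):
    new[j::4] = (ws & m).astype(np.float32) * s[j] + bias

assert np.allclose(old, new)
```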
|
mlx/ops.cpp (24 changed lines)
@@ -3275,7 +3275,9 @@ std::tuple<array, array, array> quantize(
   }

   // Compute some constants used for the quantization
-  int n_bins = (1 << bits) - 1; // 2**bits - 1
+  array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
+  array eps(1e-7, w.dtype());
+  array zero(0, w.dtype());
   int el_per_int = 32 / bits;
   array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
   shifts = reshape(shifts, {1, 1, -1}, s);
@@ -3299,16 +3301,22 @@ std::tuple<array, array, array> quantize(
       reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
   array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
   array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
-  array scales = maximum(
-      divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s),
-      array(1e-7, w.dtype()),
-      s);
-  // making sure that 0 is represented exactly in the resulting quantization
-  array biases = multiply(round(divide(w_min, scales, s), s), scales, s);
+  array mask = greater(abs(w_min, s), abs(w_max, s), s);
+  array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
+  scales = where(mask, scales, negative(scales), s);
+  array edge = where(mask, w_min, w_max, s);
+  array q0 = round(divide(edge, scales, s), s);
+  scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
+  array biases = where(equal(q0, zero, s), zero, edge);

   // Quantize and pack w
   packed_w = astype(
-      round(divide(subtract(packed_w, biases, s), scales, s), s), uint32);
+      clip(
+          round(divide(subtract(packed_w, biases, s), scales, s), s),
+          zero,
+          n_bins),
+      uint32);
   packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s);
   packed_w = sum(
       multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
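The rewritten `quantize` anchors the grid at the group edge (whichever of w_min and w_max has the larger magnitude), signs the scale so the rest of the range still maps into [0, n_bins], then snaps the scale to edge/q0 so the edge sits exactly q0 steps from zero; with bias = edge, the integer -q0 decodes to zero up to one final rounding, and an all-zero group decodes to exactly zero. A scalar NumPy sketch of the per-group math (the real code builds a lazy mlx graph; names mirror the diff above):

```python
import numpy as np

# Sketch of the new per-group scale/bias scheme, eager and scalar.
def quantize_group(w, bits=4, eps=1e-7):
    n_bins = (1 << bits) - 1
    w_min, w_max = w.min(), w.max()
    mask = abs(w_min) > abs(w_max)
    scale = max((w_max - w_min) / n_bins, eps)
    scale = scale if mask else -scale            # sign toward the edge
    edge = w_min if mask else w_max
    q0 = np.round(edge / scale)
    if q0 != 0:
        scale = edge / q0                        # snap: edge == q0 * scale
    bias = 0.0 if q0 == 0 else edge
    q = np.clip(np.round((w - bias) / scale), 0, n_bins)
    return q, scale, bias

rng = np.random.default_rng(0)
w = rng.normal(size=64)
w[::8] = 0.0                                     # plant exact zeros
q, scale, bias = quantize_group(w)
w_hat = q * scale + bias
assert np.allclose(w_hat[::8], 0.0, atol=1e-12)  # zero sits on the grid

z = np.zeros(64)                                 # the all-zero case the test adds
qz, sz, bz = quantize_group(z)
assert np.all(qz * sz + bz == 0.0)
```

The `clip` added to the packing step matters for the same reason: with the snapped scale, a rounded index can land just outside [0, n_bins] and must be clamped before packing into uint32.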
@@ -16,7 +16,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
         w_hat = mx.dequantize(w_q, scales, biases, gs, b)
         errors = (w - w_hat).abs().reshape(*scales.shape, -1)
         eps = 1e-6
-        self.assertTrue((2 * errors <= (scales[..., None] + eps)).all())
+        self.assertTrue((errors <= (scales[..., None] + eps).abs()).all())

         # test quantize/dequantize 0s
         a = mx.zeros((256, 512))
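The test changes track the new scheme: snapping the scale to the edge can stretch a bin, so the per-element error is bounded by one full |scale| rather than scale/2, and scales may now be negative, hence the `.abs()` in the assertion. A sketch of the same round trip through the public Python API, assuming the `mx.quantize`/`mx.dequantize` signatures used in this test:

```python
import mlx.core as mx

# Round trip through the public API, mirroring the updated assertions.
w = mx.random.normal((256, 512))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)

# Per-element error is bounded by one (possibly negative) scale.
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
assert (errors <= (scales[..., None] + 1e-6).abs()).all()

# An all-zero matrix survives the round trip exactly.
z = mx.zeros((256, 512))
z_q, zs, zb = mx.quantize(z, group_size=64, bits=4)
assert (mx.dequantize(z_q, zs, zb, group_size=64, bits=4) == 0).all()
```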
|