3 and 6 bit quantization (#1613)

* Support 3 and 6 bit quantization
This commit is contained in:
Alex Barron
2024-11-22 10:22:13 -08:00
committed by GitHub
parent 0c5eea226b
commit c79f6a4a8c
12 changed files with 633 additions and 419 deletions

View File

@@ -686,13 +686,11 @@ array pack_and_quantize(
array& packed_w,
const array& scales,
const array& biases,
int group_size,
int bits,
const Stream& s) {
int el_per_int = 32 / bits;
array zero(0, packed_w.dtype());
array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
packed_w = astype(
clip(
round(divide(subtract(packed_w, biases, s), scales, s), s),
@@ -701,9 +699,30 @@ array pack_and_quantize(
s),
uint32,
s);
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
if (is_power_of_2(bits)) {
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
} else {
// This is slow but we have fast GPU/CPU versions of this function so we
// shouldn't be here often.
packed_w = expand_dims(packed_w, /* axis= */ -1, s);
packed_w = bitwise_and(
right_shift(packed_w, arange(bits, uint32, s), s),
array({1}, uint32),
s);
auto new_shape = packed_w.shape();
new_shape[new_shape.size() - 2] = -1;
new_shape.back() = 32;
packed_w = reshape(packed_w, new_shape, s);
array shifts = arange(32, uint32, s);
packed_w =
sum(left_shift(packed_w, shifts, s),
/* axis= */ -1,
/* keepdims= */ false,
s);
}
return packed_w;
}
@@ -718,10 +737,10 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
throw std::invalid_argument(msg.str());
}
if (bits != 2 && bits != 4 && bits != 8) {
if (bits != 2 && bits != 3 && bits != 4 && bits != 6 && bits != 8) {
std::ostringstream msg;
msg << "[quantize] The requested number of bits " << bits
<< " is not supported. The supported bits are 2, 4 and 8.";
<< " is not supported. The supported bits are 2, 3, 4, 6 and 8.";
throw std::invalid_argument(msg.str());
}
@@ -740,9 +759,7 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
throw std::invalid_argument(msg.str());
}
int el_per_int = 32 / bits;
auto fallback = [group_size, bits, el_per_int, s](
auto fallback = [group_size, bits, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto wshape = w.shape();
@@ -765,7 +782,7 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
array biases = where(equal(q0, zero, s), zero, edge, s);
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
packed_w = pack_and_quantize(packed_w, scales, biases, bits, s);
return {
reshape(packed_w, wshape, s),
reshape(scales, wshape, s),
@@ -774,7 +791,7 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
};
auto wq_shape = w.shape();
wq_shape.back() = w.shape(-1) / el_per_int;
wq_shape.back() = w.shape(-1) * bits / 32;
auto sshape = w.shape();
sshape.back() = w.shape(-1) / group_size;
auto outputs = array::make_arrays(
@@ -785,39 +802,6 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
return {outputs[0], outputs[1], outputs[2]};
}
array affine_quantize(
const array& w,
const array& scales,
const array& biases,
int group_size,
int bits,
StreamOrDevice s_) {
auto s = to_stream(s_);
int el_per_int = 32 / bits;
auto fallback = [group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto scales = expand_dims(inputs[1], -1, s);
auto biases = expand_dims(inputs[2], -1, s);
auto wshape = w.shape();
wshape.back() = -1;
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
return {reshape(packed_w, wshape, s)};
};
auto out_shape = w.shape();
out_shape.back() = w.shape(-1) / el_per_int;
return array(
std::move(out_shape),
uint32,
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
{w, scales, biases});
}
array affine_dequantize(
const array& w,
const array& scales,
@@ -860,9 +844,9 @@ array affine_dequantize(
}
// Packing into uint32
int el_per_int = 32 / bits;
int out_size = w.shape(-1) * 32 / bits;
if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) {
if (out_size != scales.shape(-1) * group_size) {
std::ostringstream msg;
msg << "[dequantize] Shape of scales and biases does not match the matrix "
<< "given the quantization parameters. Provided matrix of shape "
@@ -873,40 +857,52 @@ array affine_dequantize(
auto s = to_stream(s_);
auto fallback =
[&wshape, &sshape, &scales, &biases, group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto fallback = [&wshape, &sshape, &scales, &biases, group_size, bits, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto w = inputs[0];
auto& scales = inputs[1];
auto& biases = inputs[2];
std::vector<array> parts;
for (int start = 0; start < 32; start += bits) {
int shift_left = 32 - (start + bits);
int shift_right = shift_left + start;
if (is_power_of_2(bits)) {
std::vector<array> parts;
for (int start = 0; start < 32; start += bits) {
int shift_left = 32 - (start + bits);
int shift_right = shift_left + start;
parts.push_back(expand_dims(
right_shift(
left_shift(w, array(32 - (start + bits), uint32), s),
array(32 - bits, uint32),
s),
-1,
s));
parts.push_back(expand_dims(
right_shift(
left_shift(w, array(32 - (start + bits), uint32), s),
array(32 - bits, uint32),
s),
-1,
s));
}
w = concatenate(parts, -1, s);
} else {
w = expand_dims(w, /* axis= */ -1, s);
w = bitwise_and(
right_shift(w, arange(32, uint32, s), s), array({1}, uint32), s);
auto new_shape = w.shape();
new_shape[new_shape.size() - 2] = -1;
new_shape.back() = bits;
w = reshape(w, new_shape, s);
array shifts = arange(bits, uint32, s);
w = sum(
left_shift(w, shifts, s), /* axis= */ -1, /* keepdims= */ false, s);
}
array w_full = concatenate(parts, -1, s);
// Dequantize
wshape.push_back(group_size);
w_full = reshape(w_full, wshape, s);
w_full = multiply(w_full, expand_dims(scales, -1, s), s);
w_full = add(w_full, expand_dims(biases, -1, s), s);
w_full = reshape(w_full, sshape, s);
w = reshape(w, wshape, s);
w = multiply(w, expand_dims(scales, -1, s), s);
w = add(w, expand_dims(biases, -1, s), s);
w = reshape(w, sshape, s);
return {w_full};
return {w};
};
if (s.device == Device::gpu) {
auto out_shape = w.shape();
out_shape.back() = w.shape(-1) * el_per_int;
out_shape.back() = out_size;
return array(
std::move(out_shape),
scales.dtype(),