Add mode parameter for quantization (#2499)

* add mode parameter for quantization

* mxfp4 quantize/dequantize + start of optional biases

* mxfp4 works

* speedup

* cpu mxfp4

* fix

* fix test tol

* fix

* refactor

* add quant mode enum
This commit is contained in:
Awni Hannun
2025-08-28 06:45:26 -07:00
committed by GitHub
parent 7ef8a6f2d5
commit 70560b6bd5
28 changed files with 3635 additions and 757 deletions

View File

@@ -762,255 +762,14 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
}
array pack_and_quantize(
array& packed_w,
const array& scales,
const array& biases,
int bits,
const Stream& s) {
int el_per_int = 32 / bits;
array zero(0, packed_w.dtype());
array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1
packed_w = astype(
clip(
round(divide(subtract(packed_w, biases, s), scales, s), s),
zero,
n_bins,
s),
uint32,
s);
if (is_power_of_2(bits)) {
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
} else {
// This is slow but we have fast GPU/CPU versions of this function so we
// shouldn't be here often.
packed_w = expand_dims(packed_w, /* axis= */ -1, s);
packed_w = bitwise_and(
right_shift(packed_w, arange(bits, uint32, s), s),
array({1}, uint32),
s);
auto new_shape = packed_w.shape();
new_shape[new_shape.size() - 2] = -1;
new_shape.back() = 32;
packed_w = reshape(packed_w, new_shape, s);
array shifts = arange(32, uint32, s);
packed_w =
sum(left_shift(packed_w, shifts, s),
/* axis= */ -1,
/* keepdims= */ false,
s);
}
return packed_w;
}
std::tuple<array, array, array>
affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
auto s = to_stream(s_);
if (group_size != 32 && group_size != 64 && group_size != 128) {
std::ostringstream msg;
msg << "[quantize] The requested group size " << group_size
<< " is not supported. The supported group sizes are 32, 64, and 128.";
throw std::invalid_argument(msg.str());
}
if (bits < 2 || bits > 8 || bits == 7) {
std::ostringstream msg;
msg << "[quantize] The requested number of bits " << bits
<< " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8.";
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
if ((w.shape(-1) % group_size) != 0) {
std::ostringstream msg;
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
<< "the quantization group size " << group_size
<< ". However the provided " << " matrix has shape " << w.shape();
throw std::invalid_argument(msg.str());
}
auto fallback = [group_size, bits, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto wshape = w.shape();
wshape.back() = -1;
array zero(0, float32);
array n_bins((1 << bits) - 1, float32); // 2**bits - 1
array eps(1e-7, float32);
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
w_max = astype(w_max, float32, s);
w_min = astype(w_min, float32, s);
array mask = greater(abs(w_min, s), abs(w_max, s), s);
array scales =
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
scales = where(mask, scales, negative(scales, s), s);
array edge = where(mask, w_min, w_max, s);
array q0 = round(divide(edge, scales, s), s);
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
array biases = where(equal(q0, zero, s), zero, edge, s);
packed_w = pack_and_quantize(packed_w, scales, biases, bits, s);
scales = astype(scales, w.dtype(), s);
biases = astype(biases, w.dtype(), s);
return {
reshape(packed_w, wshape, s),
reshape(scales, wshape, s),
reshape(biases, wshape, s),
};
};
auto wq_shape = w.shape();
wq_shape.back() = w.shape(-1) * bits / 32;
auto sshape = w.shape();
sshape.back() = w.shape(-1) / group_size;
auto outputs = array::make_arrays(
{std::move(wq_shape), sshape, sshape},
{uint32, w.dtype(), w.dtype()},
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
{w});
return {outputs[0], outputs[1], outputs[2]};
}
array affine_dequantize(
const array& w,
const array& scales,
const array& biases,
int group_size,
int bits,
StreamOrDevice s_) {
if (bits <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for bits: " << bits;
throw std::invalid_argument(msg.str());
}
if (group_size <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for group_size: " << group_size;
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
auto wshape = w.shape();
auto sshape = scales.shape();
auto bshape = biases.shape();
wshape.back() = -1;
sshape.back() = -1;
bshape.back() = -1;
if (wshape != sshape || wshape != bshape) {
throw std::invalid_argument(
"[dequantize] Shape of scales and biases does not match the matrix");
}
if (w.dtype() != uint32) {
throw std::invalid_argument(
"[dequantize] The matrix should be given as a uint32");
}
// Packing into uint32
int out_size = w.shape(-1) * 32 / bits;
if (out_size != scales.shape(-1) * group_size) {
std::ostringstream msg;
msg << "[dequantize] Shape of scales and biases does not match the matrix "
<< "given the quantization parameters. Provided matrix of shape "
<< w.shape() << " and scales/biases of shape " << scales.shape()
<< " with group_size=" << group_size << " and bits=" << bits << ".";
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
auto fallback =
[wshape = std::move(wshape),
sshape = std::move(sshape),
group_size,
bits,
s](const std::vector<array>& inputs) mutable -> std::vector<array> {
auto w = inputs[0];
auto& scales = inputs[1];
auto& biases = inputs[2];
if (is_power_of_2(bits)) {
std::vector<array> parts;
for (int start = 0; start < 32; start += bits) {
int shift_left = 32 - (start + bits);
int shift_right = shift_left + start;
parts.push_back(expand_dims(
right_shift(
left_shift(w, array(32 - (start + bits), uint32), s),
array(32 - bits, uint32),
s),
-1,
s));
}
w = concatenate(parts, -1, s);
} else {
w = expand_dims(w, /* axis= */ -1, s);
w = bitwise_and(
right_shift(w, arange(32, uint32, s), s), array({1}, uint32), s);
auto new_shape = w.shape();
new_shape[new_shape.size() - 2] = -1;
new_shape.back() = bits;
w = reshape(w, new_shape, s);
array shifts = arange(bits, uint32, s);
w = sum(
left_shift(w, shifts, s), /* axis= */ -1, /* keepdims= */ false, s);
}
// Dequantize
wshape.push_back(group_size);
w = reshape(w, wshape, s);
w = multiply(w, expand_dims(scales, -1, s), s);
w = add(w, expand_dims(biases, -1, s), s);
w = reshape(w, sshape, s);
return {w};
};
if (s.device == Device::gpu) {
auto out_shape = w.shape();
out_shape.back() = out_size;
return array(
std::move(out_shape),
scales.dtype(),
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, true),
{w, scales, biases});
}
return fallback({w, scales, biases})[0];
}
bool AffineQuantize::is_equivalent(const Primitive& other) const {
const AffineQuantize& p_other = static_cast<const AffineQuantize&>(other);
bool Quantize::is_equivalent(const Primitive& other) const {
const Quantize& p_other = static_cast<const Quantize&>(other);
return (
p_other.group_size_ == group_size_ && p_other.bits_ == bits_ &&
p_other.dequantize_ == dequantize_);
p_other.mode_ == mode_ && p_other.dequantize_ == dequantize_);
}
std::vector<Shape> AffineQuantize::output_shapes(
const std::vector<array>& inputs) {
std::vector<Shape> Quantize::output_shapes(const std::vector<array>& inputs) {
auto& w = inputs[0];
if (dequantize_) {
auto out_size = w.shape(-1) * 32 / bits_;
@@ -1022,8 +781,12 @@ std::vector<Shape> AffineQuantize::output_shapes(
wq_shape.back() = w.shape(-1) * bits_ / 32;
auto sshape = w.shape();
sshape.back() = w.shape(-1) / group_size_;
auto bshape = sshape;
return {std::move(wq_shape), std::move(sshape), std::move(bshape)};
if (inputs.size() == 2) {
return {std::move(wq_shape), std::move(sshape)};
} else {
auto bshape = sshape;
return {std::move(wq_shape), std::move(sshape), std::move(bshape)};
}
}
}