mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add mode parameter for quantization (#2499)
* add mode parameter for quantization * mxfp4 quantize/dequantize + start of optional biases * mxfp4 works * speedup * cpu mxfp4 * fix * fix test tol * fix * refactor * add quant mode enum
This commit is contained in:
257
mlx/fast.cpp
257
mlx/fast.cpp
@@ -762,255 +762,14 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
|
||||
return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
|
||||
}
|
||||
|
||||
array pack_and_quantize(
|
||||
array& packed_w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int bits,
|
||||
const Stream& s) {
|
||||
int el_per_int = 32 / bits;
|
||||
array zero(0, packed_w.dtype());
|
||||
array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1
|
||||
packed_w = astype(
|
||||
clip(
|
||||
round(divide(subtract(packed_w, biases, s), scales, s), s),
|
||||
zero,
|
||||
n_bins,
|
||||
s),
|
||||
uint32,
|
||||
s);
|
||||
if (is_power_of_2(bits)) {
|
||||
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
|
||||
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
|
||||
packed_w = sum(
|
||||
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
|
||||
} else {
|
||||
// This is slow but we have fast GPU/CPU versions of this function so we
|
||||
// shouldn't be here often.
|
||||
packed_w = expand_dims(packed_w, /* axis= */ -1, s);
|
||||
packed_w = bitwise_and(
|
||||
right_shift(packed_w, arange(bits, uint32, s), s),
|
||||
array({1}, uint32),
|
||||
s);
|
||||
auto new_shape = packed_w.shape();
|
||||
new_shape[new_shape.size() - 2] = -1;
|
||||
new_shape.back() = 32;
|
||||
packed_w = reshape(packed_w, new_shape, s);
|
||||
array shifts = arange(32, uint32, s);
|
||||
packed_w =
|
||||
sum(left_shift(packed_w, shifts, s),
|
||||
/* axis= */ -1,
|
||||
/* keepdims= */ false,
|
||||
s);
|
||||
}
|
||||
return packed_w;
|
||||
}
|
||||
|
||||
std::tuple<array, array, array>
|
||||
affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
|
||||
auto s = to_stream(s_);
|
||||
|
||||
if (group_size != 32 && group_size != 64 && group_size != 128) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The requested group size " << group_size
|
||||
<< " is not supported. The supported group sizes are 32, 64, and 128.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (bits < 2 || bits > 8 || bits == 7) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The requested number of bits " << bits
|
||||
<< " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (w.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if ((w.shape(-1) % group_size) != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
|
||||
<< "the quantization group size " << group_size
|
||||
<< ". However the provided " << " matrix has shape " << w.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto fallback = [group_size, bits, s](
|
||||
const std::vector<array>& inputs) -> std::vector<array> {
|
||||
auto& w = inputs[0];
|
||||
auto wshape = w.shape();
|
||||
wshape.back() = -1;
|
||||
|
||||
array zero(0, float32);
|
||||
array n_bins((1 << bits) - 1, float32); // 2**bits - 1
|
||||
array eps(1e-7, float32);
|
||||
|
||||
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
|
||||
|
||||
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
w_max = astype(w_max, float32, s);
|
||||
w_min = astype(w_min, float32, s);
|
||||
|
||||
array mask = greater(abs(w_min, s), abs(w_max, s), s);
|
||||
array scales =
|
||||
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
|
||||
scales = where(mask, scales, negative(scales, s), s);
|
||||
array edge = where(mask, w_min, w_max, s);
|
||||
array q0 = round(divide(edge, scales, s), s);
|
||||
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
|
||||
array biases = where(equal(q0, zero, s), zero, edge, s);
|
||||
|
||||
packed_w = pack_and_quantize(packed_w, scales, biases, bits, s);
|
||||
|
||||
scales = astype(scales, w.dtype(), s);
|
||||
biases = astype(biases, w.dtype(), s);
|
||||
return {
|
||||
reshape(packed_w, wshape, s),
|
||||
reshape(scales, wshape, s),
|
||||
reshape(biases, wshape, s),
|
||||
};
|
||||
};
|
||||
|
||||
auto wq_shape = w.shape();
|
||||
wq_shape.back() = w.shape(-1) * bits / 32;
|
||||
auto sshape = w.shape();
|
||||
sshape.back() = w.shape(-1) / group_size;
|
||||
auto outputs = array::make_arrays(
|
||||
{std::move(wq_shape), sshape, sshape},
|
||||
{uint32, w.dtype(), w.dtype()},
|
||||
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
|
||||
{w});
|
||||
return {outputs[0], outputs[1], outputs[2]};
|
||||
}
|
||||
|
||||
array affine_dequantize(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int group_size,
|
||||
int bits,
|
||||
StreamOrDevice s_) {
|
||||
if (bits <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Invalid value for bits: " << bits;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (group_size <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Invalid value for group_size: " << group_size;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto wshape = w.shape();
|
||||
auto sshape = scales.shape();
|
||||
auto bshape = biases.shape();
|
||||
wshape.back() = -1;
|
||||
sshape.back() = -1;
|
||||
bshape.back() = -1;
|
||||
|
||||
if (wshape != sshape || wshape != bshape) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] Shape of scales and biases does not match the matrix");
|
||||
}
|
||||
|
||||
if (w.dtype() != uint32) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] The matrix should be given as a uint32");
|
||||
}
|
||||
|
||||
// Packing into uint32
|
||||
int out_size = w.shape(-1) * 32 / bits;
|
||||
|
||||
if (out_size != scales.shape(-1) * group_size) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Shape of scales and biases does not match the matrix "
|
||||
<< "given the quantization parameters. Provided matrix of shape "
|
||||
<< w.shape() << " and scales/biases of shape " << scales.shape()
|
||||
<< " with group_size=" << group_size << " and bits=" << bits << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
|
||||
auto fallback =
|
||||
[wshape = std::move(wshape),
|
||||
sshape = std::move(sshape),
|
||||
group_size,
|
||||
bits,
|
||||
s](const std::vector<array>& inputs) mutable -> std::vector<array> {
|
||||
auto w = inputs[0];
|
||||
auto& scales = inputs[1];
|
||||
auto& biases = inputs[2];
|
||||
if (is_power_of_2(bits)) {
|
||||
std::vector<array> parts;
|
||||
for (int start = 0; start < 32; start += bits) {
|
||||
int shift_left = 32 - (start + bits);
|
||||
int shift_right = shift_left + start;
|
||||
|
||||
parts.push_back(expand_dims(
|
||||
right_shift(
|
||||
left_shift(w, array(32 - (start + bits), uint32), s),
|
||||
array(32 - bits, uint32),
|
||||
s),
|
||||
-1,
|
||||
s));
|
||||
}
|
||||
w = concatenate(parts, -1, s);
|
||||
} else {
|
||||
w = expand_dims(w, /* axis= */ -1, s);
|
||||
w = bitwise_and(
|
||||
right_shift(w, arange(32, uint32, s), s), array({1}, uint32), s);
|
||||
auto new_shape = w.shape();
|
||||
new_shape[new_shape.size() - 2] = -1;
|
||||
new_shape.back() = bits;
|
||||
w = reshape(w, new_shape, s);
|
||||
array shifts = arange(bits, uint32, s);
|
||||
w = sum(
|
||||
left_shift(w, shifts, s), /* axis= */ -1, /* keepdims= */ false, s);
|
||||
}
|
||||
|
||||
// Dequantize
|
||||
wshape.push_back(group_size);
|
||||
w = reshape(w, wshape, s);
|
||||
w = multiply(w, expand_dims(scales, -1, s), s);
|
||||
w = add(w, expand_dims(biases, -1, s), s);
|
||||
w = reshape(w, sshape, s);
|
||||
|
||||
return {w};
|
||||
};
|
||||
|
||||
if (s.device == Device::gpu) {
|
||||
auto out_shape = w.shape();
|
||||
out_shape.back() = out_size;
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
scales.dtype(),
|
||||
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, true),
|
||||
{w, scales, biases});
|
||||
}
|
||||
return fallback({w, scales, biases})[0];
|
||||
}
|
||||
|
||||
bool AffineQuantize::is_equivalent(const Primitive& other) const {
|
||||
const AffineQuantize& p_other = static_cast<const AffineQuantize&>(other);
|
||||
bool Quantize::is_equivalent(const Primitive& other) const {
|
||||
const Quantize& p_other = static_cast<const Quantize&>(other);
|
||||
return (
|
||||
p_other.group_size_ == group_size_ && p_other.bits_ == bits_ &&
|
||||
p_other.dequantize_ == dequantize_);
|
||||
p_other.mode_ == mode_ && p_other.dequantize_ == dequantize_);
|
||||
}
|
||||
|
||||
std::vector<Shape> AffineQuantize::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
std::vector<Shape> Quantize::output_shapes(const std::vector<array>& inputs) {
|
||||
auto& w = inputs[0];
|
||||
if (dequantize_) {
|
||||
auto out_size = w.shape(-1) * 32 / bits_;
|
||||
@@ -1022,8 +781,12 @@ std::vector<Shape> AffineQuantize::output_shapes(
|
||||
wq_shape.back() = w.shape(-1) * bits_ / 32;
|
||||
auto sshape = w.shape();
|
||||
sshape.back() = w.shape(-1) / group_size_;
|
||||
auto bshape = sshape;
|
||||
return {std::move(wq_shape), std::move(sshape), std::move(bshape)};
|
||||
if (inputs.size() == 2) {
|
||||
return {std::move(wq_shape), std::move(sshape)};
|
||||
} else {
|
||||
auto bshape = sshape;
|
||||
return {std::move(wq_shape), std::move(sshape), std::move(bshape)};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user