mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fused Affine Quantize/Dequantize ops (#1282)
* Add fast affine dequantize
* add full quantize kernel
* fused kernel with scale/bias computation
* fix docstring
* fix no jit error
* fix test
* test fix
* reduce fast api to only affine_quantize
This commit is contained in:
156
mlx/ops.cpp
156
mlx/ops.cpp
@@ -6,6 +6,7 @@
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/transforms.h"
|
||||
@@ -3356,89 +3357,7 @@ std::tuple<array, array, array> quantize(
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (group_size != 32 && group_size != 64 && group_size != 128) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The requested group size " << group_size
|
||||
<< " is not supported. The supported group sizes are 64 and 128.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (bits != 2 && bits != 4 && bits != 8) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The requested number of bits " << bits
|
||||
<< " is not supported. The supported bits are 2, 4 and 8.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (w.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if ((w.shape(-1) % group_size) != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
|
||||
<< "the quantization group size " << group_size
|
||||
<< ". However the provided " << " matrix has shape " << w.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Compute some constants used for the quantization
|
||||
array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
|
||||
array eps(1e-7, w.dtype());
|
||||
array zero(0, w.dtype());
|
||||
int el_per_int = 32 / bits;
|
||||
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
|
||||
shifts = reshape(shifts, {1, 1, -1}, s);
|
||||
|
||||
// Check that the w matrix will fill up a whole SIMD.
|
||||
// This is an implementation detail which should be removed in the future but
|
||||
// at least we bail out early which will result in a nice readable error.
|
||||
//
|
||||
// Hopefully nobody is quantizing matrices that small anyway.
|
||||
if (w.shape(-1) < 32 * el_per_int) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
|
||||
<< "too small for quantization. We support >=512 for 2 bits, "
|
||||
<< ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has "
|
||||
<< "shape " << w.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Prepare the shape for the outputs.
|
||||
auto wshape = w.shape();
|
||||
wshape.back() = -1;
|
||||
|
||||
// Compute scales and biases
|
||||
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
|
||||
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
|
||||
array mask = greater(abs(w_min, s), abs(w_max, s), s);
|
||||
array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
|
||||
scales = where(mask, scales, negative(scales), s);
|
||||
array edge = where(mask, w_min, w_max, s);
|
||||
array q0 = round(divide(edge, scales, s), s);
|
||||
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
|
||||
array biases = where(equal(q0, zero, s), zero, edge);
|
||||
|
||||
// Quantize and pack w
|
||||
packed_w = astype(
|
||||
clip(
|
||||
round(divide(subtract(packed_w, biases, s), scales, s), s),
|
||||
zero,
|
||||
n_bins),
|
||||
uint32);
|
||||
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
|
||||
packed_w = sum(
|
||||
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
|
||||
|
||||
return std::make_tuple(
|
||||
reshape(packed_w, wshape, s),
|
||||
reshape(scales, wshape, s),
|
||||
reshape(biases, wshape, s));
|
||||
return fast::affine_quantize(w, group_size, bits);
|
||||
}
|
||||
|
||||
array dequantize(
|
||||
@@ -3448,76 +3367,7 @@ array dequantize(
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (bits <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Invalid value for bits: " << bits;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (group_size <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Invalid value for group_size: " << group_size;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
|
||||
<< "but it has only " << w.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto wshape = w.shape();
|
||||
auto sshape = scales.shape();
|
||||
auto bshape = biases.shape();
|
||||
wshape.back() = -1;
|
||||
sshape.back() = -1;
|
||||
bshape.back() = -1;
|
||||
|
||||
if (wshape != sshape || wshape != bshape) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] Shape of scales and biases does not match the matrix");
|
||||
}
|
||||
|
||||
if (w.dtype() != uint32) {
|
||||
throw std::invalid_argument(
|
||||
"[dequantize] The matrix should be given as a uint32");
|
||||
}
|
||||
|
||||
// Compute some constants for the dequantization
|
||||
int el_per_int = 32 / bits;
|
||||
|
||||
if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Shape of scales and biases does not match the matrix "
|
||||
<< "given the quantization parameters. Provided matrix of shape "
|
||||
<< w.shape() << " and scales/biases of shape " << scales.shape()
|
||||
<< " with group_size=" << group_size << " and bits=" << bits << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Extract the pieces from the passed quantized matrix
|
||||
std::vector<array> parts;
|
||||
for (int start = 0; start < 32; start += bits) {
|
||||
int shift_left = 32 - (start + bits);
|
||||
int shift_right = shift_left + start;
|
||||
|
||||
parts.push_back(expand_dims(
|
||||
right_shift(
|
||||
left_shift(w, array(32 - (start + bits), uint32), s),
|
||||
array(32 - bits, uint32),
|
||||
s),
|
||||
-1,
|
||||
s));
|
||||
}
|
||||
array w_full = concatenate(parts, -1, s);
|
||||
|
||||
// Dequantize
|
||||
wshape.push_back(group_size);
|
||||
w_full = reshape(w_full, wshape, s);
|
||||
w_full = multiply(w_full, expand_dims(scales, -1, s), s);
|
||||
w_full = add(w_full, expand_dims(biases, -1, s), s);
|
||||
w_full = reshape(w_full, sshape, s);
|
||||
|
||||
return w_full;
|
||||
return fast::affine_dequantize(w, scales, biases, group_size, bits, s);
|
||||
}
|
||||
|
||||
array gather_qmm(
|
||||
|
||||
Reference in New Issue
Block a user