mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add quantize/dequantize slow path for mxfp8 and nvfp4
187 mlx/ops.cpp
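
A quick sketch of the Python API this enables, based on the tests below (shapes illustrative):

    import mlx.core as mx

    w = mx.random.normal(shape=(512, 256)).astype(mx.bfloat16)

    # mxfp8: fp8 (e4m3) elements, one shared e8m0 scale per 32 values
    w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")
    w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")

    # nvfp4: fp4 elements, one e4m3 scale per 16 values
    w_q, scales = mx.quantize(w, group_size=16, bits=4, mode="nvfp4")
    w_hat = mx.dequantize(w_q, scales, group_size=16, bits=4, mode="nvfp4")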
@@ -4018,7 +4018,8 @@ array conv_general(
 }
 
 void validate_mode(std::string_view tag, const std::string& mode) {
-  if (mode != "affine" && mode != "mxfp4") {
+  if (mode != "affine" && mode != "mxfp4" && mode != "mxfp8" &&
+      mode != "nvfp4") {
    std::ostringstream msg;
    msg << "[" << tag << "] Invalid quantization mode '" << mode << "'.";
    throw std::invalid_argument(msg.str());
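
The check above now admits four modes. For reference, the per-mode constraints enforced by the expected_gs / expected_bits logic later in this diff, written out as a small Python table (names mine):

    # mode -> (required group_size, required bits, scale encoding)
    MODE_CONSTRAINTS = {
        "mxfp4": (32, 4, "e8m0"),  # power-of-two exponent, biased by 127
        "mxfp8": (32, 8, "e8m0"),
        "nvfp4": (16, 4, "e4m3"),  # real fp8 scale
    }

affine keeps its existing per-group float scale and bias and is not affected by these checks.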
@@ -4249,51 +4250,67 @@ std::vector<array> quantize(
   if (mode == "affine") {
     return affine_quantize(w, group_size, bits, s);
   } else {
-    if (group_size != 32) {
+    int expected_gs = (mode[0] == 'm') ? 32 : 16;
+    int expected_bits = (mode.back() == '8') ? 8 : 4;
+    if (group_size != expected_gs) {
       std::ostringstream msg;
-      msg << "[quantize] mxfp4 quantization requires group size 32 "
-          << "but got " << group_size << ".";
+      msg << "[quantize] " << mode << " quantization requires group size "
+          << expected_gs << " but got " << group_size << ".";
       throw std::invalid_argument(msg.str());
     }
-    if (bits != 4) {
+    if (bits != expected_bits) {
       std::ostringstream msg;
-      msg << "[quantize] mxfp4 quantization requires bits to be 4 "
-          << "but got " << bits << ".";
+      msg << "[quantize] " << mode << " quantization requires bits to be "
+          << expected_bits << " but got " << bits << ".";
       throw std::invalid_argument(msg.str());
     }
 
-    auto lut = array({
-        +0.0f,
-        +0.5f,
-        +1.0f,
-        +1.5f,
-        +2.0f,
-        +3.0f,
-        +4.0f,
-        +6.0f,
-        -0.0f,
-        -0.5f,
-        -1.0f,
-        -1.5f,
-        -2.0f,
-        -3.0f,
-        -4.0f,
-        -6.0f,
-    });
-    lut = astype(lut, w.dtype(), s);
-
+    float maxval = (bits == 4) ? 6.0f : 448.0f;
     auto new_shape = w.shape();
     new_shape.back() = -1;
     auto wq = reshape(w, {-1, group_size}, s);
     auto scales =
-        divide(max(abs(wq, s), -1, true, s), array(6.0f, w.dtype()), s);
-    scales = astype(log2(scales, s), int32, s);
-    wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s);
-    scales = astype(add(scales, array(127, int32), s), uint8, s);
-    wq = argmin(abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s);
-    auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s);
-    wq = reshape(wq, {-1, group_size / 8, 8}, s);
-    wq = sum(multiply(wq, shifts, s), -1, false, s);
+        divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s);
+    if (group_size == 16) {
+      // convert to e4m3
+      scales = to_fp8(scales, s);
+      wq = divide(wq, from_fp8(scales, w.dtype(), s), s);
+    } else {
+      // convert to e8m0
+      auto z = array(0, scales.dtype());
+      scales =
+          where(equal(scales, z, s), z, astype(log2(scales, s), int32, s), s);
+      wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s);
+      scales = astype(add(scales, array(127, int32), s), uint8, s);
+    }
+    if (bits == 4) {
+      auto lut = array({
+          +0.0f,
+          +0.5f,
+          +1.0f,
+          +1.5f,
+          +2.0f,
+          +3.0f,
+          +4.0f,
+          +6.0f,
+          -0.0f,
+          -0.5f,
+          -1.0f,
+          -1.5f,
+          -2.0f,
+          -3.0f,
+          -4.0f,
+          -6.0f,
+      });
+      lut = astype(lut, w.dtype(), s);
+      wq = argmin(
+          abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s);
+      auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s);
+      wq = reshape(wq, {-1, 4, 8}, s);
+      wq = sum(multiply(wq, shifts, s), -1, false, s);
+    } else {
+      wq = view(to_fp8(wq, s), uint32, s);
+    }
     wq = reshape(wq, new_shape, s);
     scales = reshape(scales, new_shape, s);
     return {std::move(wq), std::move(scales)};
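
A rough NumPy analogue of the shared-exponent (e8m0) scale computation in the mxfp4/mxfp8 branch above; purely illustrative, the function name is mine, and the diff's astype-to-int32 truncation is reproduced rather than a floor:

    import numpy as np

    def mx_scale_exponents(w, group_size=32, maxval=448.0):
        # maxval is 448 for mxfp8 (largest e4m3 value) and 6 for mxfp4
        # (largest fp4 value). Each group shares one power-of-two scale,
        # stored as a biased exponent, matching add(scales, 127) -> uint8.
        g = w.reshape(-1, group_size).astype(np.float32)
        s = np.abs(g).max(axis=-1, keepdims=True) / maxval
        safe = np.where(s > 0, s, 1.0)
        # astype(..., int32) truncates toward zero, so do the same here.
        exp = np.where(s > 0, np.log2(safe).astype(np.int32), 0)
        q = g / np.exp2(exp)  # elements now fit the fp8/fp4 range
        return q, (exp + 127).astype(np.uint8)

After rescaling, the bits == 4 branch rounds each element to the nearest representable fp4 value by an argmin over the 16-entry LUT, then packs eight 4-bit indices per uint32; the bits == 8 branch converts elements to e4m3 directly with to_fp8.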
@@ -4404,6 +4421,7 @@ array dequantize(
     int group_size /* = 64 */,
     int bits /* = 4 */,
     const std::string& mode /* = "affine" */,
+    std::optional<Dtype> dtype /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   validate_mode_with_type("dequantize", scales, biases, mode);
   if (bits <= 0) {
@@ -4422,24 +4440,30 @@ array dequantize(
   }
 
   if (mode == "affine") {
-    return affine_dequantize(w, scales, *biases, group_size, bits, s);
+    auto out = affine_dequantize(w, scales, *biases, group_size, bits, s);
+    if (dtype) {
+      out = astype(out, *dtype, s);
+    }
+    return out;
   } else {
-    if (group_size != 32) {
+    int expected_gs = (mode[0] == 'm') ? 32 : 16;
+    int expected_bits = (mode.back() == '8') ? 8 : 4;
+    if (group_size != expected_gs) {
       std::ostringstream msg;
-      msg << "[dequantize] mxfp4 quantization requires group size 32 "
-          << "but got " << group_size << ".";
+      msg << "[quantize] " << mode << " quantization requires group size "
+          << expected_gs << " but got " << group_size << ".";
       throw std::invalid_argument(msg.str());
     }
-    if (bits != 4) {
+    if (bits != expected_bits) {
       std::ostringstream msg;
-      msg << "[dequantize] mxfp4 quantization requires bits to be 4 "
-          << "but got " << bits << ".";
+      msg << "[quantize] " << mode << " quantization requires bits to be "
+          << expected_bits << " but got " << bits << ".";
       throw std::invalid_argument(msg.str());
     }
 
     if (w.ndim() < 2 || scales.ndim() < 2) {
       std::ostringstream msg;
-      msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
+      msg << "[quantize] The matrix to be dequantized must have at least 2 dimension "
           << "but it has only " << w.ndim() << ".";
       throw std::invalid_argument(msg.str());
     }
@@ -4470,39 +4494,48 @@ array dequantize(
       throw std::invalid_argument(msg.str());
     }
 
-    auto dtype = bfloat16;
-    auto lut = array(
-        {
-            +0.0f,
-            +0.5f,
-            +1.0f,
-            +1.5f,
-            +2.0f,
-            +3.0f,
-            +4.0f,
-            +6.0f,
-            -0.0f,
-            -0.5f,
-            -1.0f,
-            -1.5f,
-            -2.0f,
-            -3.0f,
-            -4.0f,
-            -6.0f,
-        },
-        dtype);
-
-    auto what = view(reshape(w, {-1, group_size / 8}, s), int8, s);
-
-    auto idx_lo = bitwise_and(what, array(0x0F, int8), s);
-    auto idx_hi = right_shift(what, array(4, int8), s);
-    auto lo = gather(lut, idx_lo, 0, {1}, s);
-    auto hi = gather(lut, idx_hi, 0, {1}, s);
-    what = flatten(concatenate({lo, hi}, -1, s), -2, -1, s);
-    auto exponent = subtract(astype(scales, dtype, s), array(127, dtype), s);
-    exponent = reshape(exponent, {-1, 1}, s);
-    return reshape(
-        multiply(power(array(2.0f, dtype), exponent, s), what, s), wshape, s);
+    auto out_type = dtype.has_value() ? *dtype : bfloat16;
+    auto out = w;
+    if (bits == 4) {
+      auto lut = array(
+          {
+              +0.0f,
+              +0.5f,
+              +1.0f,
+              +1.5f,
+              +2.0f,
+              +3.0f,
+              +4.0f,
+              +6.0f,
+              -0.0f,
+              -0.5f,
+              -1.0f,
+              -1.5f,
+              -2.0f,
+              -3.0f,
+              -4.0f,
+              -6.0f,
+          },
+          out_type);
+      out = view(reshape(out, {-1, 4}, s), int8, s);
+      auto idx_lo = bitwise_and(out, array(0x0F, int8), s);
+      auto idx_hi = right_shift(out, array(4, int8), s);
+      auto lo = gather(lut, idx_lo, 0, {1}, s);
+      auto hi = gather(lut, idx_hi, 0, {1}, s);
+      out = concatenate({lo, hi}, -1, s);
+    } else {
+      out = from_fp8(view(out, uint8, s), out_type, s);
+    }
+    out = reshape(out, {-1, group_size}, s);
+    auto flat_scales = reshape(scales, {-1, 1}, s);
+    if (group_size == 16) {
+      flat_scales = from_fp8(flat_scales, out_type, s);
+    } else {
+      flat_scales =
+          subtract(astype(flat_scales, out_type, s), array(127, out_type), s);
+      flat_scales = power(array(2.0f, out_type), flat_scales, s);
+    }
+    return reshape(multiply(out, flat_scales, s), wshape, s);
   }
 }
 
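In the bits == 4 branch above, each uint32 word packs eight 4-bit LUT indices (element i at bit 4*i), and the gather into a trailing size-1 axis followed by concatenate interleaves the low and high nibble of each byte. A hedged NumPy sketch of the same unpacking, assuming a little-endian host as the int8 view in the diff does (names mine):

    import numpy as np

    FP4_LUT = np.array(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
    )

    def unpack_fp4(wq, group_size):
        # wq: uint32 array of packed indices; view as bytes, low nibble first.
        b = np.ascontiguousarray(wq).reshape(-1).view(np.uint8)
        lo = FP4_LUT[b & 0x0F]
        hi = FP4_LUT[b >> 4]
        # Interleave (lo, hi) per byte, then regroup.
        return np.stack([lo, hi], axis=-1).reshape(-1, group_size)

The unpacked groups are then multiplied by the decoded scales: from_fp8 for nvfp4's e4m3 scales, or 2^(exponent - 127) for the mx modes' e8m0 scales.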
@@ -1400,6 +1400,7 @@ array dequantize(
     int group_size = 64,
     int bits = 4,
     const std::string& mode = "affine",
+    std::optional<Dtype> dtype = std::nullopt,
     StreamOrDevice s = {});
 
 /** Convert an E4M3 float8 to the given floating point dtype. */
@@ -3404,6 +3404,7 @@ std::vector<array> QuantizedMatmul::vjp(
          group_size_,
          bits_,
          quantization_mode_to_string(mode_),
+          std::nullopt,
          stream());
      wq = unflatten(wq, -1, {-1, group_size_}, stream());
      vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));
@@ -3558,6 +3559,7 @@ std::vector<array> GatherQMM::vjp(
              group_size_,
              bits_,
              quantization_mode_to_string(mode_),
+              std::nullopt,
              stream()),
          -1,
          {-1, group_size_},
@@ -4307,10 +4307,11 @@ void init_ops(nb::module_& m) {
      "group_size"_a = 64,
      "bits"_a = 4,
      "mode"_a = "affine",
+      "dtype"_a = nb::none(),
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Dequantize the matrix ``w`` using quantization parameters.
 
@@ -4323,6 +4324,10 @@ void init_ops(nb::module_& m) {
            scale and bias. Default: ``64``.
        bits (int, optional): The number of bits occupied by each element in
            ``w``. Default: ``4``.
+        dtype (Dtype, optional): The data type of the dequantized output. If
+            ``None`` the return type is inferred from the scales and biases
+            when possible and otherwise defaults to ``bfloat16``.
+            Default: ``None``.
        mode (str, optional): The quantization mode. Default: ``"affine"``.
 
        Returns:
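With the new ``dtype`` argument callers can pick the output precision directly instead of always getting ``bfloat16`` back from the non-affine modes. A small illustrative call (array names follow the tests below):

    w_hat = mx.dequantize(
        w_q, scales, group_size=32, bits=8, mode="mxfp8", dtype=mx.float32
    )

The affine path honors it too, via the astype on the dequantized output earlier in this diff.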
@@ -77,6 +77,84 @@ class TestQuantized(mlx_tests.MLXTestCase):
         w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
         self.assertTrue(mx.all(w_hat == 0))
 
+    def test_mxfp8_quantize_dequantize(self):
+        w = 2 * mx.random.uniform(shape=(512, 32)) - 1
+        w = w.astype(mx.bfloat16)
+
+        # Invalid bits / group size
+        with self.assertRaises(ValueError):
+            mx.quantize(w, bits=3, group_size=32, mode="mxfp8")
+
+        with self.assertRaises(ValueError):
+            mx.quantize(w, group_size=32, bits=7, mode="mxfp8")
+
+        w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, bits=8, group_size=16, mode="mxfp8")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp8")
+
+        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
+        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-2))
+
+        # test quantize/dequantize 0s
+        a = mx.zeros((256, 512))
+        w_q, scales = mx.quantize(a, group_size=32, bits=8, mode="mxfp8")
+        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8")
+        self.assertTrue(mx.all(w_hat == 0))
+
+    def test_nvfp4_quantize_dequantize(self):
+        lut = mx.array(
+            [
+                +0.0,
+                +0.5,
+                +1.0,
+                +1.5,
+                +2.0,
+                +3.0,
+                +4.0,
+                +6.0,
+                -0.0,
+                -0.5,
+                -1.0,
+                -1.5,
+                -2.0,
+                -3.0,
+                -4.0,
+                -6.0,
+            ]
+        )
+        w = lut[mx.random.randint(0, 16, shape=(128, 512))]
+        w = w.reshape(-1, 16)
+        w[:, 0] = 6
+        w = (w + 3e-6).astype(mx.bfloat16)
+
+        # Invalid bits / group size
+        with self.assertRaises(ValueError):
+            mx.quantize(w, bits=3, group_size=16, mode="nvfp4")
+
+        with self.assertRaises(ValueError):
+            mx.quantize(w, group_size=64, bits=4, mode="nvfp4")
+
+        w_q, scales = mx.quantize(w, group_size=16, bits=4, mode="nvfp4")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, bits=4, group_size=32, mode="nvfp4")
+
+        with self.assertRaises(ValueError):
+            mx.dequantize(w_q, scales, group_size=32, bits=4, mode="nvfp4")
+
+        w_hat = mx.dequantize(w_q, scales, group_size=16, bits=4, mode="nvfp4")
+        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
+
+        # test quantize/dequantize 0s
+        a = mx.zeros((256, 512))
+        w_q, scales = mx.quantize(a, group_size=16, bits=4, mode="nvfp4")
+        w_hat = mx.dequantize(w_q, scales, group_size=16, bits=4, mode="nvfp4")
+        self.assertTrue(mx.all(w_hat == 0))
+
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
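One non-obvious bit of the nvfp4 test: w[:, 0] = 6 pins every group's max to 6, so each e4m3 scale is exactly 1.0 (6/6) and the FP4 grid values survive the round trip, which is what permits the tight 1e-5 tolerances; the + 3e-6 nudge presumably sidesteps sign ambiguity around the two zero entries in the LUT. For completeness, a crude NumPy stand-in for the e4m3 scale path (the diff's to_fp8/from_fp8 do proper e4m3 conversion; this only keeps 3 mantissa bits and clips):

    import numpy as np

    def nvfp4_scales(w, group_size=16):
        # One e4m3 scale per 16 values: max|group| / 6 (6 = largest FP4 value).
        g = w.reshape(-1, group_size).astype(np.float32)
        s = np.abs(g).max(axis=-1, keepdims=True) / 6.0
        # Crude e4m3 stand-in: round to 3 mantissa bits, clip to the e4m3
        # range (the clip floor also dodges 0/0 for all-zero groups).
        m, e = np.frexp(s)
        s8 = np.clip(np.ldexp(np.round(m * 16.0) / 16.0, e), 2.0**-9, 448.0)
        return g / s8, s8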