Change the argument name to quantization_type

Angelos Katharopoulos 2024-12-16 13:31:34 -08:00
parent f5da489a3c
commit 410ccdbed5
4 changed files with 78 additions and 50 deletions


@@ -20,14 +20,14 @@ def qmv_(x, wq1, wq2, q_type):
         *wq1,
         group_size=group_size,
         bits=bits,
-        type=q_type,
+        quantization_type=q_type,
     )
     x = mx.quantized_matmul(
         x,
         *wq2,
         group_size=group_size,
         bits=bits,
-        type=q_type,
+        quantization_type=q_type,
     )
     return x
 
@@ -44,9 +44,9 @@ def time_qmv():
     mx.random.seed(3)
     x = mx.random.normal(shape=(1, D)).astype(dtype)
     w1 = mx.random.normal(shape=(M, D)).astype(dtype)
-    wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine")
+    wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
     w2 = mx.random.normal(shape=(D, M)).astype(dtype)
-    wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine")
+    wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
     mx.eval(x, wq1, wq2)
     time_fn(affine_qmv, x, wq1, wq2)
 
@@ -55,15 +55,19 @@ def time_packed_qmv():
     mx.random.seed(3)
     x = mx.random.normal(shape=(1, D)).astype(dtype)
     w1 = mx.random.normal(shape=(M, D)).astype(dtype)
-    wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine-packed")
+    wq1 = mx.quantize(
+        w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
+    )
     w2 = mx.random.normal(shape=(D, M)).astype(dtype)
-    wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine-packed")
+    wq2 = mx.quantize(
+        w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
+    )
     mx.eval(x, wq1, wq2)
     time_fn(affine_packed_qmv, x, wq1, wq2)
 
 
 if __name__ == "__main__":
-    for b in [2, 3, 4, 6, 8]:
+    for b in [2, 4, 8]:
         bits = b
         print(f"Bits {bits}:")
         time_qmv()
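
Note: the three hunks above just thread the new keyword through the benchmark. A minimal standalone sketch of the renamed Python API as the benchmark exercises it (shapes, dtype, and sizes are illustrative):

import mlx.core as mx

D, group_size, bits = 1024, 64, 4
x = mx.random.normal(shape=(1, D)).astype(mx.float16)
w = mx.random.normal(shape=(D, D)).astype(mx.float16)

# quantize() returns (wq, scales, biases); unpacking the tuple with *
# passes all three arrays positionally, just as qmv_ does above.
wq = mx.quantize(w, group_size=group_size, bits=bits, quantization_type="affine")
y = mx.quantized_matmul(
    x, *wq, group_size=group_size, bits=bits, quantization_type="affine"
)
mx.eval(y)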


@@ -79,29 +79,29 @@ std::pair<int, int> extract_quantized_matmul_dims(
     bool transpose,
     int group_size,
     int bits,
-    QuantizationType type) {
+    QuantizationType quantization_type) {
   // Check if we have biases as expected
-  switch (type) {
+  switch (quantization_type) {
     case QuantizationType::Affine:
       if (!biases.has_value()) {
         std::ostringstream msg;
         msg << "[" << tag
             << "] The biases argument is required for quantization "
-            << "type '" << type << "'";
+            << "type '" << quantization_type << "'";
         throw std::invalid_argument(msg.str());
       }
       break;
     case QuantizationType::AffinePacked:
       if (biases.has_value()) {
         std::ostringstream msg;
-        msg << "[" << tag << "] Quantization type '" << type
+        msg << "[" << tag << "] Quantization type '" << quantization_type
             << "' does not use "
             << "biases but biases were provided";
         throw std::invalid_argument(msg.str());
       }
       if (bits & (bits - 1)) {
         std::ostringstream msg;
-        msg << "[" << tag << "] Quantization type '" << type
+        msg << "[" << tag << "] Quantization type '" << quantization_type
             << "' does not support " << bits << " bits.";
         throw std::invalid_argument(msg.str());
       }
@@ -135,7 +135,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
   int weight_dims = w.shape(-1) * 32 / bits;
   int scales_dims = scales.shape(-1) * group_size;
-  if (type == QuantizationType::AffinePacked) {
+  if (quantization_type == QuantizationType::AffinePacked) {
     scales_dims /= 8;
     weight_dims /= 4;
   }
 
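A note on the two checks above, since they also motivate the benchmark change from [2, 3, 4, 6, 8] to [2, 4, 8]: `bits & (bits - 1)` is non-zero exactly when `bits` is not a power of two, and for 'affine-packed' the last dimensions shrink by the packing factors 8 (scales) and 4 (weights). A small Python sketch of the same arithmetic (a hypothetical helper; the factors come straight from the code above):

def expected_dims(w_last, scales_last, group_size, bits, packed):
    # Each 32-bit word holds 32 // bits quantized elements.
    weight_dims = w_last * 32 // bits
    scales_dims = scales_last * group_size
    if packed:  # 'affine-packed': scales packed 8x, weight rows 4x
        scales_dims //= 8
        weight_dims //= 4
    return weight_dims, scales_dims

for b in [2, 3, 4, 6, 8]:
    print(b, (b & (b - 1)) == 0)  # True only for the power-of-two widths
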
@@ -147,7 +147,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
         << "w.shape() == " << w.shape()
         << " and scales.shape() == " << scales.shape()
         << " with group_size=" << group_size << ", bits=" << bits
-        << " and type='" << type << "'";
+        << " and type='" << quantization_type << "'";
     throw std::invalid_argument(msg.str());
   }
 
@@ -155,7 +155,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
 
   // Calculate the expanded w's dims
   int weight_dims_other = w.shape(-2);
-  if (type == QuantizationType::AffinePacked) {
+  if (quantization_type == QuantizationType::AffinePacked) {
     weight_dims_other *= 4;
   }
   int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
@@ -3708,7 +3708,7 @@ array quantized_matmul(
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
-    QuantizationType type /* = QuantizationType::Affine */,
+    QuantizationType quantization_type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   // Check and extract the quantized matrix shape against x
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
@@ -3720,7 +3720,7 @@ array quantized_matmul(
       transpose,
       group_size,
       bits,
-      type);
+      quantization_type);
 
   // QuantizedMatmul handles w.ndim == 2 case.
   if (x.ndim() > 2 && w.ndim() > 2) {
@@ -3779,7 +3779,7 @@ array quantized_matmul(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
-          to_stream(s), type, group_size, bits, transpose),
+          to_stream(s), quantization_type, group_size, bits, transpose),
       std::move(inputs));
 }
 
@@ -3787,16 +3787,16 @@ std::tuple<array, array, std::optional<array>> quantize(
     const array& w,
     int group_size /* = 64 */,
     int bits /* = 4 */,
-    QuantizationType type /* = QuantizationType::Affine */,
+    QuantizationType quantization_type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
-  switch (type) {
+  switch (quantization_type) {
     case QuantizationType::Affine:
       return fast::affine_quantize(w, group_size, bits, s);
     case QuantizationType::AffinePacked: {
       if (bits & (bits - 1)) {
         std::ostringstream msg;
-        msg << "[quantize] Quantization type '" << type << "' does not support "
-            << bits << " bits.";
+        msg << "[quantize] Quantization type '" << quantization_type
+            << "' does not support " << bits << " bits.";
         throw std::invalid_argument(msg.str());
       }
       auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
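
For orientation, `fast::affine_quantize` stores one scale and bias per `group_size` elements; the 'affine-packed' branch then repacks those outputs. A simplified per-group model in Python (a sketch of the standard affine scheme, not the actual kernel):

import numpy as np

def affine_quantize_group(w, bits):
    # Map one group of floats to integer codes plus (scale, bias).
    w_min, w_max = w.min(), w.max()
    scale = (w_max - w_min) / (2**bits - 1)
    bias = w_min
    q = np.round((w - bias) / scale).astype(np.uint32)
    return q, scale, bias

def affine_dequantize_group(q, scale, bias):
    # Inverse map: exact up to the rounding above.
    return scale * q + bias
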
@@ -3822,7 +3822,7 @@ array dequantize(
     const std::optional<array>& biases,
     int group_size /* = 64 */,
     int bits /* = 4 */,
-    QuantizationType type /* = QuantizationType::Affine */,
+    QuantizationType quantization_type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   return fast::affine_dequantize(
       w, scales, biases.value(), group_size, bits, s);
@@ -3838,21 +3838,38 @@ array gather_qmm(
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
-    QuantizationType type /* = QuantizationType::Affine */,
+    QuantizationType quantization_type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   if (!lhs_indices_ && !rhs_indices_) {
     return quantized_matmul(
-        x, w, scales, biases, transpose, group_size, bits, type, s);
+        x,
+        w,
+        scales,
+        biases,
+        transpose,
+        group_size,
+        bits,
+        quantization_type,
+        s);
   }
 
-  if (type != QuantizationType::Affine) {
+  if (quantization_type != QuantizationType::Affine) {
     std::ostringstream msg;
-    msg << "[gather_qmm] Only quantization type '" << type << "' supported";
+    msg << "[gather_qmm] Only quantization type '" << quantization_type
+        << "' supported";
     throw std::invalid_argument(msg.str());
   }
 
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "gather_qmm", x, w, scales, biases, transpose, group_size, bits, type);
+      "gather_qmm",
+      x,
+      w,
+      scales,
+      biases,
+      transpose,
+      group_size,
+      bits,
+      quantization_type);
 
   // Extract indices and broadcast them
   array lhs_indices = indices_or_default(lhs_indices_, x, s);
@@ -3874,7 +3891,7 @@ array gather_qmm(
       std::move(out_shape),
       out_type,
       std::make_shared<GatherQMM>(
-          to_stream(s), type, group_size, bits, transpose),
+          to_stream(s), quantization_type, group_size, bits, transpose),
       {astype(x, out_type, s),
        w,
        astype(scales, out_type, s),


@@ -1286,7 +1286,7 @@ array quantized_matmul(
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
-    QuantizationType type = QuantizationType::Affine,
+    QuantizationType quantization_type = QuantizationType::Affine,
     StreamOrDevice s = {});
 
 /** Quantize a matrix along its last axis */
@@ -1294,7 +1294,7 @@ std::tuple<array, array, std::optional<array>> quantize(
     const array& w,
     int group_size = 64,
     int bits = 4,
-    QuantizationType type = QuantizationType::Affine,
+    QuantizationType quantization_type = QuantizationType::Affine,
     StreamOrDevice s = {});
 
 /** Dequantize a matrix produced by quantize() */
@@ -1304,7 +1304,7 @@ array dequantize(
     const std::optional<array>& biases,
     int group_size = 64,
     int bits = 4,
-    QuantizationType type = QuantizationType::Affine,
+    QuantizationType quantization_type = QuantizationType::Affine,
     StreamOrDevice s = {});
 
 /** Compute matrix products with matrix-level gather. */
@@ -1318,7 +1318,7 @@ array gather_qmm(
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
-    QuantizationType type = QuantizationType::Affine,
+    QuantizationType quantization_type = QuantizationType::Affine,
     StreamOrDevice s = {});
 
 /** Returns a contraction of a and b over multiple dimensions. */


@@ -4025,7 +4025,7 @@ void init_ops(nb::module_& m) {
          bool transpose,
          int group_size,
          int bits,
-         const std::string& type,
+         const std::string& quantization_type,
          mx::StreamOrDevice s) {
         return mx::quantized_matmul(
             std::move(x),
@@ -4035,7 +4035,7 @@ void init_ops(nb::module_& m) {
             transpose,
             group_size,
             bits,
-            mx::from_string(type),
+            mx::from_string(quantization_type),
             s);
       },
       nb::arg(),
@@ -4045,11 +4045,11 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
-      "type"_a = "affine",
+      "quantization_type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the matrix multiplication with the quantized matrix ``w``. The
         quantization uses one floating point scale and bias per ``group_size`` of
@@ -4069,7 +4069,7 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
-          type (str, optional): The type of quantization used for the matrix.
+          quantization_type (str, optional): The type of quantization used for the matrix.
             It can be 'affine' or 'affine-packed'.
 
         Returns:
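
A hedged usage sketch for the renamed binding (sizes illustrative; with the default transpose=True the product is taken against the transposed quantized matrix):

import mlx.core as mx

w = mx.random.normal(shape=(256, 512))
wq, scales, biases = mx.quantize(w)  # defaults: group_size=64, bits=4, affine
x = mx.random.normal(shape=(4, 512))
y = mx.quantized_matmul(x, wq, scales, biases, quantization_type="affine")
print(y.shape)  # (4, 256)
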
@@ -4080,18 +4080,19 @@ void init_ops(nb::module_& m) {
       [](const mx::array& w,
          int group_size,
          int bits,
-         const std::string& type,
+         const std::string& quantization_type,
          mx::StreamOrDevice s) {
-        return mx::quantize(w, group_size, bits, mx::from_string(type), s);
+        return mx::quantize(
+            w, group_size, bits, mx::from_string(quantization_type), s);
       },
       nb::arg(),
       "group_size"_a = 64,
       "bits"_a = 4,
-      "type"_a = "affine",
+      "quantization_type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
+          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
       R"pbdoc(
         Quantize the matrix ``w`` using ``bits`` bits per element.
 
@@ -4133,7 +4134,7 @@ void init_ops(nb::module_& m) {
             scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element of
             ``w`` in the returned quantized matrix. Default: ``4``.
-          type (str, optional): The type of quantization used for the matrix.
+          quantization_type (str, optional): The type of quantization used for the matrix.
             It can be 'affine' or 'affine-packed'.
 
         Returns:
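
A sketch of the return convention the signature above encodes (the final assert reflects the Optional[array] slot: per the ops.cpp hunk, 'affine-packed' folds the biases into the packed layout, so the third element comes back as None; this is an inference from the diff, not documented behavior):

import mlx.core as mx

w = mx.random.normal(shape=(128, 256))
wq, scales, biases = mx.quantize(w, quantization_type="affine")
wq_p, scales_p, biases_p = mx.quantize(w, quantization_type="affine-packed")
assert biases is not None and biases_p is None
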
@@ -4152,21 +4153,27 @@ void init_ops(nb::module_& m) {
          const std::optional<mx::array>& biases,
          int group_size,
          int bits,
-         const std::string& type,
+         const std::string& quantization_type,
          mx::StreamOrDevice s) {
         return mx::dequantize(
-            wq, scales, biases, group_size, bits, mx::from_string(type), s);
+            wq,
+            scales,
+            biases,
+            group_size,
+            bits,
+            mx::from_string(quantization_type),
+            s);
       },
       nb::arg(),
       "scales"_a,
       "biases"_a,
       "group_size"_a = 64,
       "bits"_a = 4,
-      "type"_a = "affine",
+      "quantization_type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Dequantize the matrix ``w`` using the provided ``scales`` and
         ``biases`` and the ``group_size`` and ``bits`` configuration.
@@ -4187,7 +4194,7 @@ void init_ops(nb::module_& m) {
             scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
-          type (str, optional): The type of quantization used for the matrix.
+          quantization_type (str, optional): The type of quantization used for the matrix.
             It can be 'affine' or 'affine-packed'.
 
         Returns:
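
A roundtrip sketch for the renamed dequantize (reconstruction is exact only up to quantization rounding, roughly half a scale step per entry):

import mlx.core as mx

w = mx.random.normal(shape=(64, 256))
wq, scales, biases = mx.quantize(w, bits=8, quantization_type="affine")
w_hat = mx.dequantize(wq, scales, biases, bits=8, quantization_type="affine")
print(mx.abs(w - w_hat).max())  # small quantization error
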
@@ -4205,7 +4212,7 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
-      "type"_a = "affine",
+      "quantization_type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
@@ -4235,7 +4242,7 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
-          type (str, optional): The type of quantization used for the matrix.
+          quantization_type (str, optional): The type of quantization used for the matrix.
             It can be 'affine' or 'affine-packed'.
 
         Returns:
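
Finally, a sketch of gather_qmm with the renamed argument (the stacked-expert shapes and index semantics are assumptions for illustration; note the ops.cpp hunk above rejects anything but 'affine' on the gather path):

import mlx.core as mx

x = mx.random.normal(shape=(2, 1, 512))
w = mx.random.normal(shape=(4, 256, 512))  # a stack of 4 candidate matrices
wq, scales, biases = mx.quantize(w)        # quantized along the last axis
rhs_indices = mx.array([0, 3])             # pick one matrix per batch entry
y = mx.gather_qmm(
    x, wq, scales, biases, rhs_indices=rhs_indices, quantization_type="affine"
)
mx.eval(y)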