Change the argument name to quantization_type

Angelos Katharopoulos 2024-12-16 13:31:34 -08:00
parent f5da489a3c
commit 410ccdbed5
4 changed files with 78 additions and 50 deletions

View File

@@ -20,14 +20,14 @@ def qmv_(x, wq1, wq2, q_type):
         *wq1,
         group_size=group_size,
         bits=bits,
-        type=q_type,
+        quantization_type=q_type,
     )
     x = mx.quantized_matmul(
         x,
         *wq2,
         group_size=group_size,
         bits=bits,
-        type=q_type,
+        quantization_type=q_type,
     )
     return x
@@ -44,9 +44,9 @@ def time_qmv():
     mx.random.seed(3)
     x = mx.random.normal(shape=(1, D)).astype(dtype)
     w1 = mx.random.normal(shape=(M, D)).astype(dtype)
-    wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine")
+    wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
     w2 = mx.random.normal(shape=(D, M)).astype(dtype)
-    wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine")
+    wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
     mx.eval(x, wq1, wq2)
     time_fn(affine_qmv, x, wq1, wq2)
@@ -55,15 +55,19 @@ def time_packed_qmv():
     mx.random.seed(3)
     x = mx.random.normal(shape=(1, D)).astype(dtype)
     w1 = mx.random.normal(shape=(M, D)).astype(dtype)
-    wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine-packed")
+    wq1 = mx.quantize(
+        w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
+    )
     w2 = mx.random.normal(shape=(D, M)).astype(dtype)
-    wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine-packed")
+    wq2 = mx.quantize(
+        w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
+    )
     mx.eval(x, wq1, wq2)
     time_fn(affine_packed_qmv, x, wq1, wq2)


 if __name__ == "__main__":
-    for b in [2, 3, 4, 6, 8]:
+    for b in [2, 4, 8]:
         bits = b
         print(f"Bits {bits}:")
         time_qmv()
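
After this change, the keyword is quantization_type across the Python API. A minimal round-trip sketch (shapes are illustrative; it assumes only the signatures visible in this commit):

import mlx.core as mx

# Quantize a weight matrix with the renamed keyword (formerly `type`).
w = mx.random.normal(shape=(512, 1024))
wq, scales, biases = mx.quantize(w, group_size=64, bits=4, quantization_type="affine")

# Multiply against the quantized weights; transpose=True treats w as (M, D).
x = mx.random.normal(shape=(1, 1024))
y = mx.quantized_matmul(
    x, wq, scales, biases,
    transpose=True, group_size=64, bits=4, quantization_type="affine",
)
print(y.shape)  # (1, 512)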

View File

@@ -79,29 +79,29 @@ std::pair<int, int> extract_quantized_matmul_dims(
     bool transpose,
     int group_size,
     int bits,
-    QuantizationType type) {
+    QuantizationType quantization_type) {
   // Check if we have biases as expected
-  switch (type) {
+  switch (quantization_type) {
     case QuantizationType::Affine:
       if (!biases.has_value()) {
         std::ostringstream msg;
         msg << "[" << tag
             << "] The biases argument is required for quantization "
-            << "type '" << type << "'";
+            << "type '" << quantization_type << "'";
         throw std::invalid_argument(msg.str());
       }
       break;
     case QuantizationType::AffinePacked:
       if (biases.has_value()) {
         std::ostringstream msg;
-        msg << "[" << tag << "] Quantization type '" << type
+        msg << "[" << tag << "] Quantization type '" << quantization_type
             << "' does not use "
             << "biases but biases were provided";
         throw std::invalid_argument(msg.str());
       }
       if (bits & (bits - 1)) {
         std::ostringstream msg;
-        msg << "[" << tag << "] Quantization type '" << type
+        msg << "[" << tag << "] Quantization type '" << quantization_type
             << "' does not support " << bits << " bits.";
         throw std::invalid_argument(msg.str());
       }
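
The bits & (bits - 1) guard above is the usual power-of-two test: the expression is zero exactly when bits has a single set bit, so 'affine-packed' accepts only power-of-two bit widths (presumably why the benchmark sweep above shrinks from [2, 3, 4, 6, 8] to [2, 4, 8]). A quick illustration in Python:

def is_power_of_two(bits: int) -> bool:
    # Clearing the lowest set bit leaves 0 only for powers of two.
    return bits > 0 and (bits & (bits - 1)) == 0

print([b for b in [2, 3, 4, 6, 8] if is_power_of_two(b)])  # [2, 4, 8]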
@@ -135,7 +135,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
   int weight_dims = w.shape(-1) * 32 / bits;
   int scales_dims = scales.shape(-1) * group_size;
-  if (type == QuantizationType::AffinePacked) {
+  if (quantization_type == QuantizationType::AffinePacked) {
     scales_dims /= 8;
     weight_dims /= 4;
   }
@@ -147,7 +147,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
         << "w.shape() == " << w.shape()
         << " and scales.shape() == " << scales.shape()
         << " with group_size=" << group_size << ", bits=" << bits
-        << " and type='" << type << "'";
+        << " and type='" << quantization_type << "'";
     throw std::invalid_argument(msg.str());
   }
@@ -155,7 +155,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
   // Calculate the expanded w's dims
   int weight_dims_other = w.shape(-2);
-  if (type == QuantizationType::AffinePacked) {
+  if (quantization_type == QuantizationType::AffinePacked) {
     weight_dims_other *= 4;
   }
   int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
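
The three hunks above recover the logical matrix dimensions from packed storage: each 32-bit word of w holds 32 / bits elements, each scale covers group_size elements, and the affine-packed layout folds an extra factor of 4 into the row dimension and 8 into the scales. A sketch of the same bookkeeping (expected_dims is a hypothetical mirror of the C++ here, not part of the API):

def expected_dims(w_shape, scales_shape, group_size, bits, packed, transpose):
    weight_dims = w_shape[-1] * 32 // bits       # elements per packed row
    scales_dims = scales_shape[-1] * group_size  # elements covered by scales
    weight_dims_other = w_shape[-2]
    if packed:
        scales_dims //= 8
        weight_dims //= 4
        weight_dims_other *= 4
    # scales_dims must agree with weight_dims or the shape error above fires;
    # the outer dim is presumably the complement of the inner one.
    w_inner = weight_dims if transpose else weight_dims_other
    w_outer = weight_dims_other if transpose else weight_dims
    return w_inner, w_outer, scales_dims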
@@ -3708,7 +3708,7 @@ array quantized_matmul(
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
-    QuantizationType type /* = QuantizationType::Affine */,
+    QuantizationType quantization_type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   // Check and extract the quantized matrix shape against x
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
@@ -3720,7 +3720,7 @@ array quantized_matmul(
       transpose,
       group_size,
       bits,
-      type);
+      quantization_type);

   // QuantizedMatmul handles w.ndim == 2 case.
   if (x.ndim() > 2 && w.ndim() > 2) {
@@ -3779,7 +3779,7 @@ array quantized_matmul(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
-          to_stream(s), type, group_size, bits, transpose),
+          to_stream(s), quantization_type, group_size, bits, transpose),
       std::move(inputs));
 }
@@ -3787,16 +3787,16 @@ std::tuple<array, array, std::optional<array>> quantize(
     const array& w,
     int group_size /* = 64 */,
     int bits /* = 4 */,
-    QuantizationType type /* = QuantizationType::Affine */,
+    QuantizationType quantization_type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
-  switch (type) {
+  switch (quantization_type) {
     case QuantizationType::Affine:
       return fast::affine_quantize(w, group_size, bits, s);
     case QuantizationType::AffinePacked: {
       if (bits & (bits - 1)) {
         std::ostringstream msg;
-        msg << "[quantize] Quantization type '" << type << "' does not support "
-            << bits << " bits.";
+        msg << "[quantize] Quantization type '" << quantization_type
+            << "' does not support " << bits << " bits.";
         throw std::invalid_argument(msg.str());
       }
       auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
@@ -3822,7 +3822,7 @@ array dequantize(
     const std::optional<array>& biases,
     int group_size /* = 64 */,
     int bits /* = 4 */,
-    QuantizationType type /* = QuantizationType::Affine */,
+    QuantizationType quantization_type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   return fast::affine_dequantize(
       w, scales, biases.value(), group_size, bits, s);
@@ -3838,21 +3838,38 @@ array gather_qmm(
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
-    QuantizationType type /* = QuantizationType::Affine */,
+    QuantizationType quantization_type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   if (!lhs_indices_ && !rhs_indices_) {
-    return quantized_matmul(
-        x, w, scales, biases, transpose, group_size, bits, type, s);
+    return quantized_matmul(
+        x,
+        w,
+        scales,
+        biases,
+        transpose,
+        group_size,
+        bits,
+        quantization_type,
+        s);
   }

-  if (type != QuantizationType::Affine) {
+  if (quantization_type != QuantizationType::Affine) {
     std::ostringstream msg;
-    msg << "[gather_qmm] Only quantization type '" << type << "' supported";
+    msg << "[gather_qmm] Only quantization type '" << quantization_type
+        << "' supported";
     throw std::invalid_argument(msg.str());
   }

   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "gather_qmm", x, w, scales, biases, transpose, group_size, bits, type);
+      "gather_qmm",
+      x,
+      w,
+      scales,
+      biases,
+      transpose,
+      group_size,
+      bits,
+      quantization_type);

   // Extract indices and broadcast them
   array lhs_indices = indices_or_default(lhs_indices_, x, s);
@@ -3874,7 +3891,7 @@ array gather_qmm(
       std::move(out_shape),
       out_type,
       std::make_shared<GatherQMM>(
-          to_stream(s), type, group_size, bits, transpose),
+          to_stream(s), quantization_type, group_size, bits, transpose),
       {astype(x, out_type, s),
        w,
        astype(scales, out_type, s),
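
dequantize picks up the same keyword at the Python layer. A minimal round-trip sketch for the 'affine' type (values are illustrative; the reconstruction error depends on group_size and bits):

import mlx.core as mx

w = mx.random.normal(shape=(256, 512))
wq, scales, biases = mx.quantize(w, group_size=64, bits=4, quantization_type="affine")
w_hat = mx.dequantize(
    wq, scales=scales, biases=biases,
    group_size=64, bits=4, quantization_type="affine",
)
print(mx.max(mx.abs(w - w_hat)))  # small quantization error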

View File

@@ -1286,7 +1286,7 @@ array quantized_matmul(
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
-    QuantizationType type = QuantizationType::Affine,
+    QuantizationType quantization_type = QuantizationType::Affine,
     StreamOrDevice s = {});

 /** Quantize a matrix along its last axis */
@@ -1294,7 +1294,7 @@ std::tuple<array, array, std::optional<array>> quantize(
     const array& w,
     int group_size = 64,
     int bits = 4,
-    QuantizationType type = QuantizationType::Affine,
+    QuantizationType quantization_type = QuantizationType::Affine,
     StreamOrDevice s = {});

 /** Dequantize a matrix produced by quantize() */
@@ -1304,7 +1304,7 @@ array dequantize(
     const std::optional<array>& biases,
     int group_size = 64,
     int bits = 4,
-    QuantizationType type = QuantizationType::Affine,
+    QuantizationType quantization_type = QuantizationType::Affine,
     StreamOrDevice s = {});

 /** Compute matrix products with matrix-level gather. */
@@ -1318,7 +1318,7 @@ array gather_qmm(
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
-    QuantizationType type = QuantizationType::Affine,
+    QuantizationType quantization_type = QuantizationType::Affine,
     StreamOrDevice s = {});

 /** Returns a contraction of a and b over multiple dimensions. */

View File

@@ -4025,7 +4025,7 @@ void init_ops(nb::module_& m) {
          bool transpose,
          int group_size,
          int bits,
-         const std::string& type,
+         const std::string& quantization_type,
          mx::StreamOrDevice s) {
         return mx::quantized_matmul(
             std::move(x),
@@ -4035,7 +4035,7 @@ void init_ops(nb::module_& m) {
             transpose,
             group_size,
             bits,
-            mx::from_string(type),
+            mx::from_string(quantization_type),
             s);
       },
       nb::arg(),
@@ -4045,11 +4045,11 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
-      "type"_a = "affine",
+      "quantization_type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the matrix multiplication with the quantized matrix ``w``. The
         quantization uses one floating point scale and bias per ``group_size`` of
@@ -4069,7 +4069,7 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
-          type (str, optional): The type of quantization used for the matrix.
+          quantization_type (str, optional): The type of quantization used for the matrix.
             It can be 'affine' or 'affine-packed'.

         Returns:
@@ -4080,18 +4080,19 @@ void init_ops(nb::module_& m) {
       [](const mx::array& w,
          int group_size,
          int bits,
-         const std::string& type,
+         const std::string& quantization_type,
          mx::StreamOrDevice s) {
-        return mx::quantize(w, group_size, bits, mx::from_string(type), s);
+        return mx::quantize(
+            w, group_size, bits, mx::from_string(quantization_type), s);
       },
       nb::arg(),
       "group_size"_a = 64,
       "bits"_a = 4,
-      "type"_a = "affine",
+      "quantization_type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
+          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
       R"pbdoc(
         Quantize the matrix ``w`` using ``bits`` bits per element.
@@ -4133,7 +4134,7 @@ void init_ops(nb::module_& m) {
             scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element of
             ``w`` in the returned quantized matrix. Default: ``4``.
-          type (str, optional): The type of quantization used for the matrix.
+          quantization_type (str, optional): The type of quantization used for the matrix.
             It can be 'affine' or 'affine-packed'.

         Returns:
@@ -4152,21 +4153,27 @@ void init_ops(nb::module_& m) {
          const std::optional<mx::array>& biases,
          int group_size,
          int bits,
-         const std::string& type,
+         const std::string& quantization_type,
          mx::StreamOrDevice s) {
-        return mx::dequantize(
-            wq, scales, biases, group_size, bits, mx::from_string(type), s);
+        return mx::dequantize(
+            wq,
+            scales,
+            biases,
+            group_size,
+            bits,
+            mx::from_string(quantization_type),
+            s);
       },
       nb::arg(),
       "scales"_a,
       "biases"_a,
       "group_size"_a = 64,
       "bits"_a = 4,
-      "type"_a = "affine",
+      "quantization_type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Dequantize the matrix ``w`` using the provided ``scales`` and
         ``biases`` and the ``group_size`` and ``bits`` configuration.
@@ -4187,7 +4194,7 @@ void init_ops(nb::module_& m) {
             scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
-          type (str, optional): The type of quantization used for the matrix.
+          quantization_type (str, optional): The type of quantization used for the matrix.
             It can be 'affine' or 'affine-packed'.

         Returns:
@@ -4205,7 +4212,7 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
-      "type"_a = "affine",
+      "quantization_type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
@@ -4235,7 +4242,7 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
-          type (str, optional): The type of quantization used for the matrix.
+          quantization_type (str, optional): The type of quantization used for the matrix.
             It can be 'affine' or 'affine-packed'.

         Returns:
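
For 'affine-packed' the biases slot is unused, which is why quantize returns tuple[array, array, Optional[array]] and quantized_matmul accepts biases: Optional[array]. A hedged sketch of the packed path (it assumes the empty optional surfaces as None in Python):

import mlx.core as mx

w = mx.random.normal(shape=(512, 1024))
wq, scales, biases = mx.quantize(
    w, group_size=64, bits=4, quantization_type="affine-packed"
)
assert biases is None  # the packed type does not use a separate biases array

x = mx.random.normal(shape=(1, 1024))
y = mx.quantized_matmul(
    x, wq, scales, None,
    group_size=64, bits=4, quantization_type="affine-packed",
)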