mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 20:41:13 +08:00
Change the argument name to quantization_type
This commit is contained in:
parent
f5da489a3c
commit
410ccdbed5
@ -20,14 +20,14 @@ def qmv_(x, wq1, wq2, q_type):
|
||||
*wq1,
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
type=q_type,
|
||||
quantization_type=q_type,
|
||||
)
|
||||
x = mx.quantized_matmul(
|
||||
x,
|
||||
*wq2,
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
type=q_type,
|
||||
quantization_type=q_type,
|
||||
)
|
||||
return x
|
||||
|
||||
@ -44,9 +44,9 @@ def time_qmv():
|
||||
mx.random.seed(3)
|
||||
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
||||
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
||||
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine")
|
||||
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
|
||||
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
|
||||
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine")
|
||||
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
|
||||
mx.eval(x, wq1, wq2)
|
||||
time_fn(affine_qmv, x, wq1, wq2)
|
||||
|
||||
@ -55,15 +55,19 @@ def time_packed_qmv():
|
||||
mx.random.seed(3)
|
||||
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
||||
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
||||
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine-packed")
|
||||
wq1 = mx.quantize(
|
||||
w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
|
||||
)
|
||||
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
|
||||
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine-packed")
|
||||
wq2 = mx.quantize(
|
||||
w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
|
||||
)
|
||||
mx.eval(x, wq1, wq2)
|
||||
time_fn(affine_packed_qmv, x, wq1, wq2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for b in [2, 3, 4, 6, 8]:
|
||||
for b in [2, 4, 8]:
|
||||
bits = b
|
||||
print(f"Bits {bits}:")
|
||||
time_qmv()
|
||||
|
61
mlx/ops.cpp
61
mlx/ops.cpp
@ -79,29 +79,29 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
bool transpose,
|
||||
int group_size,
|
||||
int bits,
|
||||
QuantizationType type) {
|
||||
QuantizationType quantization_type) {
|
||||
// Check if we have biases as expected
|
||||
switch (type) {
|
||||
switch (quantization_type) {
|
||||
case QuantizationType::Affine:
|
||||
if (!biases.has_value()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag
|
||||
<< "] The biases argument is required for quantization "
|
||||
<< "type '" << type << "'";
|
||||
<< "type '" << quantization_type << "'";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
break;
|
||||
case QuantizationType::AffinePacked:
|
||||
if (biases.has_value()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] Quantization type '" << type
|
||||
msg << "[" << tag << "] Quantization type '" << quantization_type
|
||||
<< "' does not use "
|
||||
<< "biases but biases were provided";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (bits & (bits - 1)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] Quantization type '" << type
|
||||
msg << "[" << tag << "] Quantization type '" << quantization_type
|
||||
<< "' does not support " << bits << " bits.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
@ -135,7 +135,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
|
||||
int weight_dims = w.shape(-1) * 32 / bits;
|
||||
int scales_dims = scales.shape(-1) * group_size;
|
||||
if (type == QuantizationType::AffinePacked) {
|
||||
if (quantization_type == QuantizationType::AffinePacked) {
|
||||
scales_dims /= 8;
|
||||
weight_dims /= 4;
|
||||
}
|
||||
@ -147,7 +147,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
<< "w.shape() == " << w.shape()
|
||||
<< " and scales.shape() == " << scales.shape()
|
||||
<< " with group_size=" << group_size << ", bits=" << bits
|
||||
<< " and type='" << type << "'";
|
||||
<< " and type='" << quantization_type << "'";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@ -155,7 +155,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
|
||||
// Calculate the expanded w's dims
|
||||
int weight_dims_other = w.shape(-2);
|
||||
if (type == QuantizationType::AffinePacked) {
|
||||
if (quantization_type == QuantizationType::AffinePacked) {
|
||||
weight_dims_other *= 4;
|
||||
}
|
||||
int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
|
||||
@ -3708,7 +3708,7 @@ array quantized_matmul(
|
||||
bool transpose /* = true */,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
QuantizationType type /* = QuantizationType::Affine */,
|
||||
QuantizationType quantization_type /* = QuantizationType::Affine */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// Check and extract the quantized matrix shape against x
|
||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||
@ -3720,7 +3720,7 @@ array quantized_matmul(
|
||||
transpose,
|
||||
group_size,
|
||||
bits,
|
||||
type);
|
||||
quantization_type);
|
||||
|
||||
// QuantizedMatmul handles w.ndim == 2 case.
|
||||
if (x.ndim() > 2 && w.ndim() > 2) {
|
||||
@ -3779,7 +3779,7 @@ array quantized_matmul(
|
||||
std::move(out_shape),
|
||||
dtype,
|
||||
std::make_shared<QuantizedMatmul>(
|
||||
to_stream(s), type, group_size, bits, transpose),
|
||||
to_stream(s), quantization_type, group_size, bits, transpose),
|
||||
std::move(inputs));
|
||||
}
|
||||
|
||||
@ -3787,16 +3787,16 @@ std::tuple<array, array, std::optional<array>> quantize(
|
||||
const array& w,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
QuantizationType type /* = QuantizationType::Affine */,
|
||||
QuantizationType quantization_type /* = QuantizationType::Affine */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
switch (type) {
|
||||
switch (quantization_type) {
|
||||
case QuantizationType::Affine:
|
||||
return fast::affine_quantize(w, group_size, bits, s);
|
||||
case QuantizationType::AffinePacked: {
|
||||
if (bits & (bits - 1)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] Quantization type '" << type << "' does not support "
|
||||
<< bits << " bits.";
|
||||
msg << "[quantize] Quantization type '" << quantization_type
|
||||
<< "' does not support " << bits << " bits.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
|
||||
@ -3822,7 +3822,7 @@ array dequantize(
|
||||
const std::optional<array>& biases,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
QuantizationType type /* = QuantizationType::Affine */,
|
||||
QuantizationType quantization_type /* = QuantizationType::Affine */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return fast::affine_dequantize(
|
||||
w, scales, biases.value(), group_size, bits, s);
|
||||
@ -3838,21 +3838,38 @@ array gather_qmm(
|
||||
bool transpose /* = true */,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
QuantizationType type /* = QuantizationType::Affine */,
|
||||
QuantizationType quantization_type /* = QuantizationType::Affine */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (!lhs_indices_ && !rhs_indices_) {
|
||||
return quantized_matmul(
|
||||
x, w, scales, biases, transpose, group_size, bits, type, s);
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
transpose,
|
||||
group_size,
|
||||
bits,
|
||||
quantization_type,
|
||||
s);
|
||||
}
|
||||
|
||||
if (type != QuantizationType::Affine) {
|
||||
if (quantization_type != QuantizationType::Affine) {
|
||||
std::ostringstream msg;
|
||||
msg << "[gather_qmm] Only quantization type '" << type << "' supported";
|
||||
msg << "[gather_qmm] Only quantization type '" << quantization_type
|
||||
<< "' supported";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||
"gather_qmm", x, w, scales, biases, transpose, group_size, bits, type);
|
||||
"gather_qmm",
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
transpose,
|
||||
group_size,
|
||||
bits,
|
||||
quantization_type);
|
||||
|
||||
// Extract indices and broadcast them
|
||||
array lhs_indices = indices_or_default(lhs_indices_, x, s);
|
||||
@ -3874,7 +3891,7 @@ array gather_qmm(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
std::make_shared<GatherQMM>(
|
||||
to_stream(s), type, group_size, bits, transpose),
|
||||
to_stream(s), quantization_type, group_size, bits, transpose),
|
||||
{astype(x, out_type, s),
|
||||
w,
|
||||
astype(scales, out_type, s),
|
||||
|
@ -1286,7 +1286,7 @@ array quantized_matmul(
|
||||
bool transpose = true,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
QuantizationType type = QuantizationType::Affine,
|
||||
QuantizationType quantization_type = QuantizationType::Affine,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Quantize a matrix along its last axis */
|
||||
@ -1294,7 +1294,7 @@ std::tuple<array, array, std::optional<array>> quantize(
|
||||
const array& w,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
QuantizationType type = QuantizationType::Affine,
|
||||
QuantizationType quantization_type = QuantizationType::Affine,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Dequantize a matrix produced by quantize() */
|
||||
@ -1304,7 +1304,7 @@ array dequantize(
|
||||
const std::optional<array>& biases,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
QuantizationType type = QuantizationType::Affine,
|
||||
QuantizationType quantization_type = QuantizationType::Affine,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Compute matrix products with matrix-level gather. */
|
||||
@ -1318,7 +1318,7 @@ array gather_qmm(
|
||||
bool transpose = true,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
QuantizationType type = QuantizationType::Affine,
|
||||
QuantizationType quantization_type = QuantizationType::Affine,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Returns a contraction of a and b over multiple dimensions. */
|
||||
|
@ -4025,7 +4025,7 @@ void init_ops(nb::module_& m) {
|
||||
bool transpose,
|
||||
int group_size,
|
||||
int bits,
|
||||
const std::string& type,
|
||||
const std::string& quantization_type,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::quantized_matmul(
|
||||
std::move(x),
|
||||
@ -4035,7 +4035,7 @@ void init_ops(nb::module_& m) {
|
||||
transpose,
|
||||
group_size,
|
||||
bits,
|
||||
mx::from_string(type),
|
||||
mx::from_string(quantization_type),
|
||||
s);
|
||||
},
|
||||
nb::arg(),
|
||||
@ -4045,11 +4045,11 @@ void init_ops(nb::module_& m) {
|
||||
"transpose"_a = true,
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"type"_a = "affine",
|
||||
"quantization_type"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform the matrix multiplication with the quantized matrix ``w``. The
|
||||
quantization uses one floating point scale and bias per ``group_size`` of
|
||||
@ -4069,7 +4069,7 @@ void init_ops(nb::module_& m) {
|
||||
shares a scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. Default: ``4``.
|
||||
type (str, optional): The type of quantization used for the matrix.
|
||||
quantization_type (str, optional): The type of quantization used for the matrix.
|
||||
It can be 'affine' or 'affine-packed'.
|
||||
|
||||
Returns:
|
||||
@ -4080,18 +4080,19 @@ void init_ops(nb::module_& m) {
|
||||
[](const mx::array& w,
|
||||
int group_size,
|
||||
int bits,
|
||||
const std::string& type,
|
||||
const std::string& quantization_type,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::quantize(w, group_size, bits, mx::from_string(type), s);
|
||||
return mx::quantize(
|
||||
w, group_size, bits, mx::from_string(quantization_type), s);
|
||||
},
|
||||
nb::arg(),
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"type"_a = "affine",
|
||||
"quantization_type"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
|
||||
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
|
||||
R"pbdoc(
|
||||
Quantize the matrix ``w`` using ``bits`` bits per element.
|
||||
|
||||
@ -4133,7 +4134,7 @@ void init_ops(nb::module_& m) {
|
||||
scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element of
|
||||
``w`` in the returned quantized matrix. Default: ``4``.
|
||||
type (str, optional): The type of quantization used for the matrix.
|
||||
quantization_type (str, optional): The type of quantization used for the matrix.
|
||||
It can be 'affine' or 'affine-packed'.
|
||||
|
||||
Returns:
|
||||
@ -4152,21 +4153,27 @@ void init_ops(nb::module_& m) {
|
||||
const std::optional<mx::array>& biases,
|
||||
int group_size,
|
||||
int bits,
|
||||
const std::string& type,
|
||||
const std::string& quantization_type,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::dequantize(
|
||||
wq, scales, biases, group_size, bits, mx::from_string(type), s);
|
||||
wq,
|
||||
scales,
|
||||
biases,
|
||||
group_size,
|
||||
bits,
|
||||
mx::from_string(quantization_type),
|
||||
s);
|
||||
},
|
||||
nb::arg(),
|
||||
"scales"_a,
|
||||
"biases"_a,
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"type"_a = "affine",
|
||||
"quantization_type"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Dequantize the matrix ``w`` using the provided ``scales`` and
|
||||
``biases`` and the ``group_size`` and ``bits`` configuration.
|
||||
@ -4187,7 +4194,7 @@ void init_ops(nb::module_& m) {
|
||||
scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. Default: ``4``.
|
||||
type (str, optional): The type of quantization used for the matrix.
|
||||
quantization_type (str, optional): The type of quantization used for the matrix.
|
||||
It can be 'affine' or 'affine-packed'.
|
||||
|
||||
Returns:
|
||||
@ -4205,7 +4212,7 @@ void init_ops(nb::module_& m) {
|
||||
"transpose"_a = true,
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"type"_a = "affine",
|
||||
"quantization_type"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
@ -4235,7 +4242,7 @@ void init_ops(nb::module_& m) {
|
||||
shares a scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. Default: ``4``.
|
||||
type (str, optional): The type of quantization used for the matrix.
|
||||
quantization_type (str, optional): The type of quantization used for the matrix.
|
||||
It can be 'affine' or 'affine-packed'.
|
||||
|
||||
Returns:
|
||||
|
Loading…
Reference in New Issue
Block a user