mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 20:41:13 +08:00
Change the argument name to quantization_type
This commit is contained in:
parent
f5da489a3c
commit
410ccdbed5
@ -20,14 +20,14 @@ def qmv_(x, wq1, wq2, q_type):
|
|||||||
*wq1,
|
*wq1,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
bits=bits,
|
bits=bits,
|
||||||
type=q_type,
|
quantization_type=q_type,
|
||||||
)
|
)
|
||||||
x = mx.quantized_matmul(
|
x = mx.quantized_matmul(
|
||||||
x,
|
x,
|
||||||
*wq2,
|
*wq2,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
bits=bits,
|
bits=bits,
|
||||||
type=q_type,
|
quantization_type=q_type,
|
||||||
)
|
)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -44,9 +44,9 @@ def time_qmv():
|
|||||||
mx.random.seed(3)
|
mx.random.seed(3)
|
||||||
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
||||||
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
||||||
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine")
|
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
|
||||||
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
|
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
|
||||||
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine")
|
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
|
||||||
mx.eval(x, wq1, wq2)
|
mx.eval(x, wq1, wq2)
|
||||||
time_fn(affine_qmv, x, wq1, wq2)
|
time_fn(affine_qmv, x, wq1, wq2)
|
||||||
|
|
||||||
@ -55,15 +55,19 @@ def time_packed_qmv():
|
|||||||
mx.random.seed(3)
|
mx.random.seed(3)
|
||||||
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
||||||
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
||||||
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine-packed")
|
wq1 = mx.quantize(
|
||||||
|
w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
|
||||||
|
)
|
||||||
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
|
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
|
||||||
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine-packed")
|
wq2 = mx.quantize(
|
||||||
|
w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
|
||||||
|
)
|
||||||
mx.eval(x, wq1, wq2)
|
mx.eval(x, wq1, wq2)
|
||||||
time_fn(affine_packed_qmv, x, wq1, wq2)
|
time_fn(affine_packed_qmv, x, wq1, wq2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
for b in [2, 3, 4, 6, 8]:
|
for b in [2, 4, 8]:
|
||||||
bits = b
|
bits = b
|
||||||
print(f"Bits {bits}:")
|
print(f"Bits {bits}:")
|
||||||
time_qmv()
|
time_qmv()
|
||||||
|
61
mlx/ops.cpp
61
mlx/ops.cpp
@ -79,29 +79,29 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
|||||||
bool transpose,
|
bool transpose,
|
||||||
int group_size,
|
int group_size,
|
||||||
int bits,
|
int bits,
|
||||||
QuantizationType type) {
|
QuantizationType quantization_type) {
|
||||||
// Check if we have biases as expected
|
// Check if we have biases as expected
|
||||||
switch (type) {
|
switch (quantization_type) {
|
||||||
case QuantizationType::Affine:
|
case QuantizationType::Affine:
|
||||||
if (!biases.has_value()) {
|
if (!biases.has_value()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[" << tag
|
msg << "[" << tag
|
||||||
<< "] The biases argument is required for quantization "
|
<< "] The biases argument is required for quantization "
|
||||||
<< "type '" << type << "'";
|
<< "type '" << quantization_type << "'";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case QuantizationType::AffinePacked:
|
case QuantizationType::AffinePacked:
|
||||||
if (biases.has_value()) {
|
if (biases.has_value()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[" << tag << "] Quantization type '" << type
|
msg << "[" << tag << "] Quantization type '" << quantization_type
|
||||||
<< "' does not use "
|
<< "' does not use "
|
||||||
<< "biases but biases were provided";
|
<< "biases but biases were provided";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
if (bits & (bits - 1)) {
|
if (bits & (bits - 1)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[" << tag << "] Quantization type '" << type
|
msg << "[" << tag << "] Quantization type '" << quantization_type
|
||||||
<< "' does not support " << bits << " bits.";
|
<< "' does not support " << bits << " bits.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
@ -135,7 +135,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
|||||||
|
|
||||||
int weight_dims = w.shape(-1) * 32 / bits;
|
int weight_dims = w.shape(-1) * 32 / bits;
|
||||||
int scales_dims = scales.shape(-1) * group_size;
|
int scales_dims = scales.shape(-1) * group_size;
|
||||||
if (type == QuantizationType::AffinePacked) {
|
if (quantization_type == QuantizationType::AffinePacked) {
|
||||||
scales_dims /= 8;
|
scales_dims /= 8;
|
||||||
weight_dims /= 4;
|
weight_dims /= 4;
|
||||||
}
|
}
|
||||||
@ -147,7 +147,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
|||||||
<< "w.shape() == " << w.shape()
|
<< "w.shape() == " << w.shape()
|
||||||
<< " and scales.shape() == " << scales.shape()
|
<< " and scales.shape() == " << scales.shape()
|
||||||
<< " with group_size=" << group_size << ", bits=" << bits
|
<< " with group_size=" << group_size << ", bits=" << bits
|
||||||
<< " and type='" << type << "'";
|
<< " and type='" << quantization_type << "'";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -155,7 +155,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
|||||||
|
|
||||||
// Calculate the expanded w's dims
|
// Calculate the expanded w's dims
|
||||||
int weight_dims_other = w.shape(-2);
|
int weight_dims_other = w.shape(-2);
|
||||||
if (type == QuantizationType::AffinePacked) {
|
if (quantization_type == QuantizationType::AffinePacked) {
|
||||||
weight_dims_other *= 4;
|
weight_dims_other *= 4;
|
||||||
}
|
}
|
||||||
int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
|
int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
|
||||||
@ -3708,7 +3708,7 @@ array quantized_matmul(
|
|||||||
bool transpose /* = true */,
|
bool transpose /* = true */,
|
||||||
int group_size /* = 64 */,
|
int group_size /* = 64 */,
|
||||||
int bits /* = 4 */,
|
int bits /* = 4 */,
|
||||||
QuantizationType type /* = QuantizationType::Affine */,
|
QuantizationType quantization_type /* = QuantizationType::Affine */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
// Check and extract the quantized matrix shape against x
|
// Check and extract the quantized matrix shape against x
|
||||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||||
@ -3720,7 +3720,7 @@ array quantized_matmul(
|
|||||||
transpose,
|
transpose,
|
||||||
group_size,
|
group_size,
|
||||||
bits,
|
bits,
|
||||||
type);
|
quantization_type);
|
||||||
|
|
||||||
// QuantizedMatmul handles w.ndim == 2 case.
|
// QuantizedMatmul handles w.ndim == 2 case.
|
||||||
if (x.ndim() > 2 && w.ndim() > 2) {
|
if (x.ndim() > 2 && w.ndim() > 2) {
|
||||||
@ -3779,7 +3779,7 @@ array quantized_matmul(
|
|||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
dtype,
|
dtype,
|
||||||
std::make_shared<QuantizedMatmul>(
|
std::make_shared<QuantizedMatmul>(
|
||||||
to_stream(s), type, group_size, bits, transpose),
|
to_stream(s), quantization_type, group_size, bits, transpose),
|
||||||
std::move(inputs));
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3787,16 +3787,16 @@ std::tuple<array, array, std::optional<array>> quantize(
|
|||||||
const array& w,
|
const array& w,
|
||||||
int group_size /* = 64 */,
|
int group_size /* = 64 */,
|
||||||
int bits /* = 4 */,
|
int bits /* = 4 */,
|
||||||
QuantizationType type /* = QuantizationType::Affine */,
|
QuantizationType quantization_type /* = QuantizationType::Affine */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
switch (type) {
|
switch (quantization_type) {
|
||||||
case QuantizationType::Affine:
|
case QuantizationType::Affine:
|
||||||
return fast::affine_quantize(w, group_size, bits, s);
|
return fast::affine_quantize(w, group_size, bits, s);
|
||||||
case QuantizationType::AffinePacked: {
|
case QuantizationType::AffinePacked: {
|
||||||
if (bits & (bits - 1)) {
|
if (bits & (bits - 1)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[quantize] Quantization type '" << type << "' does not support "
|
msg << "[quantize] Quantization type '" << quantization_type
|
||||||
<< bits << " bits.";
|
<< "' does not support " << bits << " bits.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
|
auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
|
||||||
@ -3822,7 +3822,7 @@ array dequantize(
|
|||||||
const std::optional<array>& biases,
|
const std::optional<array>& biases,
|
||||||
int group_size /* = 64 */,
|
int group_size /* = 64 */,
|
||||||
int bits /* = 4 */,
|
int bits /* = 4 */,
|
||||||
QuantizationType type /* = QuantizationType::Affine */,
|
QuantizationType quantization_type /* = QuantizationType::Affine */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
return fast::affine_dequantize(
|
return fast::affine_dequantize(
|
||||||
w, scales, biases.value(), group_size, bits, s);
|
w, scales, biases.value(), group_size, bits, s);
|
||||||
@ -3838,21 +3838,38 @@ array gather_qmm(
|
|||||||
bool transpose /* = true */,
|
bool transpose /* = true */,
|
||||||
int group_size /* = 64 */,
|
int group_size /* = 64 */,
|
||||||
int bits /* = 4 */,
|
int bits /* = 4 */,
|
||||||
QuantizationType type /* = QuantizationType::Affine */,
|
QuantizationType quantization_type /* = QuantizationType::Affine */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
if (!lhs_indices_ && !rhs_indices_) {
|
if (!lhs_indices_ && !rhs_indices_) {
|
||||||
return quantized_matmul(
|
return quantized_matmul(
|
||||||
x, w, scales, biases, transpose, group_size, bits, type, s);
|
x,
|
||||||
|
w,
|
||||||
|
scales,
|
||||||
|
biases,
|
||||||
|
transpose,
|
||||||
|
group_size,
|
||||||
|
bits,
|
||||||
|
quantization_type,
|
||||||
|
s);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (type != QuantizationType::Affine) {
|
if (quantization_type != QuantizationType::Affine) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[gather_qmm] Only quantization type '" << type << "' supported";
|
msg << "[gather_qmm] Only quantization type '" << quantization_type
|
||||||
|
<< "' supported";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||||
"gather_qmm", x, w, scales, biases, transpose, group_size, bits, type);
|
"gather_qmm",
|
||||||
|
x,
|
||||||
|
w,
|
||||||
|
scales,
|
||||||
|
biases,
|
||||||
|
transpose,
|
||||||
|
group_size,
|
||||||
|
bits,
|
||||||
|
quantization_type);
|
||||||
|
|
||||||
// Extract indices and broadcast them
|
// Extract indices and broadcast them
|
||||||
array lhs_indices = indices_or_default(lhs_indices_, x, s);
|
array lhs_indices = indices_or_default(lhs_indices_, x, s);
|
||||||
@ -3874,7 +3891,7 @@ array gather_qmm(
|
|||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
out_type,
|
out_type,
|
||||||
std::make_shared<GatherQMM>(
|
std::make_shared<GatherQMM>(
|
||||||
to_stream(s), type, group_size, bits, transpose),
|
to_stream(s), quantization_type, group_size, bits, transpose),
|
||||||
{astype(x, out_type, s),
|
{astype(x, out_type, s),
|
||||||
w,
|
w,
|
||||||
astype(scales, out_type, s),
|
astype(scales, out_type, s),
|
||||||
|
@ -1286,7 +1286,7 @@ array quantized_matmul(
|
|||||||
bool transpose = true,
|
bool transpose = true,
|
||||||
int group_size = 64,
|
int group_size = 64,
|
||||||
int bits = 4,
|
int bits = 4,
|
||||||
QuantizationType type = QuantizationType::Affine,
|
QuantizationType quantization_type = QuantizationType::Affine,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Quantize a matrix along its last axis */
|
/** Quantize a matrix along its last axis */
|
||||||
@ -1294,7 +1294,7 @@ std::tuple<array, array, std::optional<array>> quantize(
|
|||||||
const array& w,
|
const array& w,
|
||||||
int group_size = 64,
|
int group_size = 64,
|
||||||
int bits = 4,
|
int bits = 4,
|
||||||
QuantizationType type = QuantizationType::Affine,
|
QuantizationType quantization_type = QuantizationType::Affine,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Dequantize a matrix produced by quantize() */
|
/** Dequantize a matrix produced by quantize() */
|
||||||
@ -1304,7 +1304,7 @@ array dequantize(
|
|||||||
const std::optional<array>& biases,
|
const std::optional<array>& biases,
|
||||||
int group_size = 64,
|
int group_size = 64,
|
||||||
int bits = 4,
|
int bits = 4,
|
||||||
QuantizationType type = QuantizationType::Affine,
|
QuantizationType quantization_type = QuantizationType::Affine,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Compute matrix products with matrix-level gather. */
|
/** Compute matrix products with matrix-level gather. */
|
||||||
@ -1318,7 +1318,7 @@ array gather_qmm(
|
|||||||
bool transpose = true,
|
bool transpose = true,
|
||||||
int group_size = 64,
|
int group_size = 64,
|
||||||
int bits = 4,
|
int bits = 4,
|
||||||
QuantizationType type = QuantizationType::Affine,
|
QuantizationType quantization_type = QuantizationType::Affine,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Returns a contraction of a and b over multiple dimensions. */
|
/** Returns a contraction of a and b over multiple dimensions. */
|
||||||
|
@ -4025,7 +4025,7 @@ void init_ops(nb::module_& m) {
|
|||||||
bool transpose,
|
bool transpose,
|
||||||
int group_size,
|
int group_size,
|
||||||
int bits,
|
int bits,
|
||||||
const std::string& type,
|
const std::string& quantization_type,
|
||||||
mx::StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
return mx::quantized_matmul(
|
return mx::quantized_matmul(
|
||||||
std::move(x),
|
std::move(x),
|
||||||
@ -4035,7 +4035,7 @@ void init_ops(nb::module_& m) {
|
|||||||
transpose,
|
transpose,
|
||||||
group_size,
|
group_size,
|
||||||
bits,
|
bits,
|
||||||
mx::from_string(type),
|
mx::from_string(quantization_type),
|
||||||
s);
|
s);
|
||||||
},
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
@ -4045,11 +4045,11 @@ void init_ops(nb::module_& m) {
|
|||||||
"transpose"_a = true,
|
"transpose"_a = true,
|
||||||
"group_size"_a = 64,
|
"group_size"_a = 64,
|
||||||
"bits"_a = 4,
|
"bits"_a = 4,
|
||||||
"type"_a = "affine",
|
"quantization_type"_a = "affine",
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Perform the matrix multiplication with the quantized matrix ``w``. The
|
Perform the matrix multiplication with the quantized matrix ``w``. The
|
||||||
quantization uses one floating point scale and bias per ``group_size`` of
|
quantization uses one floating point scale and bias per ``group_size`` of
|
||||||
@ -4069,7 +4069,7 @@ void init_ops(nb::module_& m) {
|
|||||||
shares a scale and bias. Default: ``64``.
|
shares a scale and bias. Default: ``64``.
|
||||||
bits (int, optional): The number of bits occupied by each element in
|
bits (int, optional): The number of bits occupied by each element in
|
||||||
``w``. Default: ``4``.
|
``w``. Default: ``4``.
|
||||||
type (str, optional): The type of quantization used for the matrix.
|
quantization_type (str, optional): The type of quantization used for the matrix.
|
||||||
It can be 'affine' or 'affine-packed'.
|
It can be 'affine' or 'affine-packed'.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -4080,18 +4080,19 @@ void init_ops(nb::module_& m) {
|
|||||||
[](const mx::array& w,
|
[](const mx::array& w,
|
||||||
int group_size,
|
int group_size,
|
||||||
int bits,
|
int bits,
|
||||||
const std::string& type,
|
const std::string& quantization_type,
|
||||||
mx::StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
return mx::quantize(w, group_size, bits, mx::from_string(type), s);
|
return mx::quantize(
|
||||||
|
w, group_size, bits, mx::from_string(quantization_type), s);
|
||||||
},
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
"group_size"_a = 64,
|
"group_size"_a = 64,
|
||||||
"bits"_a = 4,
|
"bits"_a = 4,
|
||||||
"type"_a = "affine",
|
"quantization_type"_a = "affine",
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
|
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Quantize the matrix ``w`` using ``bits`` bits per element.
|
Quantize the matrix ``w`` using ``bits`` bits per element.
|
||||||
|
|
||||||
@ -4133,7 +4134,7 @@ void init_ops(nb::module_& m) {
|
|||||||
scale and bias. Default: ``64``.
|
scale and bias. Default: ``64``.
|
||||||
bits (int, optional): The number of bits occupied by each element of
|
bits (int, optional): The number of bits occupied by each element of
|
||||||
``w`` in the returned quantized matrix. Default: ``4``.
|
``w`` in the returned quantized matrix. Default: ``4``.
|
||||||
type (str, optional): The type of quantization used for the matrix.
|
quantization_type (str, optional): The type of quantization used for the matrix.
|
||||||
It can be 'affine' or 'affine-packed'.
|
It can be 'affine' or 'affine-packed'.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -4152,21 +4153,27 @@ void init_ops(nb::module_& m) {
|
|||||||
const std::optional<mx::array>& biases,
|
const std::optional<mx::array>& biases,
|
||||||
int group_size,
|
int group_size,
|
||||||
int bits,
|
int bits,
|
||||||
const std::string& type,
|
const std::string& quantization_type,
|
||||||
mx::StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
return mx::dequantize(
|
return mx::dequantize(
|
||||||
wq, scales, biases, group_size, bits, mx::from_string(type), s);
|
wq,
|
||||||
|
scales,
|
||||||
|
biases,
|
||||||
|
group_size,
|
||||||
|
bits,
|
||||||
|
mx::from_string(quantization_type),
|
||||||
|
s);
|
||||||
},
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
"scales"_a,
|
"scales"_a,
|
||||||
"biases"_a,
|
"biases"_a,
|
||||||
"group_size"_a = 64,
|
"group_size"_a = 64,
|
||||||
"bits"_a = 4,
|
"bits"_a = 4,
|
||||||
"type"_a = "affine",
|
"quantization_type"_a = "affine",
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Dequantize the matrix ``w`` using the provided ``scales`` and
|
Dequantize the matrix ``w`` using the provided ``scales`` and
|
||||||
``biases`` and the ``group_size`` and ``bits`` configuration.
|
``biases`` and the ``group_size`` and ``bits`` configuration.
|
||||||
@ -4187,7 +4194,7 @@ void init_ops(nb::module_& m) {
|
|||||||
scale and bias. Default: ``64``.
|
scale and bias. Default: ``64``.
|
||||||
bits (int, optional): The number of bits occupied by each element in
|
bits (int, optional): The number of bits occupied by each element in
|
||||||
``w``. Default: ``4``.
|
``w``. Default: ``4``.
|
||||||
type (str, optional): The type of quantization used for the matrix.
|
quantization_type (str, optional): The type of quantization used for the matrix.
|
||||||
It can be 'affine' or 'affine-packed'.
|
It can be 'affine' or 'affine-packed'.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -4205,7 +4212,7 @@ void init_ops(nb::module_& m) {
|
|||||||
"transpose"_a = true,
|
"transpose"_a = true,
|
||||||
"group_size"_a = 64,
|
"group_size"_a = 64,
|
||||||
"bits"_a = 4,
|
"bits"_a = 4,
|
||||||
"type"_a = "affine",
|
"quantization_type"_a = "affine",
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
@ -4235,7 +4242,7 @@ void init_ops(nb::module_& m) {
|
|||||||
shares a scale and bias. Default: ``64``.
|
shares a scale and bias. Default: ``64``.
|
||||||
bits (int, optional): The number of bits occupied by each element in
|
bits (int, optional): The number of bits occupied by each element in
|
||||||
``w``. Default: ``4``.
|
``w``. Default: ``4``.
|
||||||
type (str, optional): The type of quantization used for the matrix.
|
quantization_type (str, optional): The type of quantization used for the matrix.
|
||||||
It can be 'affine' or 'affine-packed'.
|
It can be 'affine' or 'affine-packed'.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
Loading…
Reference in New Issue
Block a user