mirror of https://github.com/ml-explore/mlx.git
synced 2025-08-28 00:36:32 +08:00

add mode parameter for quantization

This commit is contained in:
parent e843c4d8d5
commit 8ec8d44ee6
@@ -127,7 +127,7 @@ relying on a copy from ``ensure_row_contiguous``:
         name="myexp_strided",
         input_names=["inp"],
         output_names=["out"],
-        source=source
+        source=source,
+        ensure_row_contiguous=False,
     )

@@ -4029,6 +4029,7 @@ array quantized_matmul(
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    const std::string& mode /* = "affine" */,
     StreamOrDevice s /* = {} */) {
   // Check and extract the quantized matrix shape against x
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
@@ -4056,7 +4057,7 @@ array quantized_matmul(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
-          to_stream(s), group_size, bits, transpose),
+          to_stream(s), group_size, bits, mode, transpose),
       std::move(inputs));
 }

@@ -4064,6 +4065,7 @@ std::tuple<array, array, array> quantize(
     const array& w,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    const std::string& mode /* = "affine" */,
     StreamOrDevice s /* = {} */) {
   return fast::affine_quantize(w, group_size, bits, s);
 }
@@ -4074,6 +4076,7 @@ array dequantize(
     const array& biases,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    const std::string& mode /* = "affine" */,
     StreamOrDevice s /* = {} */) {
   return fast::affine_dequantize(w, scales, biases, group_size, bits, s);
 }
@@ -4088,11 +4091,12 @@ array gather_qmm(
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    const std::string& mode /* = "affine" */,
     bool sorted_indices /* = false */,
     StreamOrDevice s /* = {} */) {
   if (!lhs_indices_ && !rhs_indices_) {
     return quantized_matmul(
-        x, w, scales, biases, transpose, group_size, bits, s);
+        x, w, scales, biases, transpose, group_size, bits, mode, s);
   }

   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
@@ -4132,6 +4136,7 @@ array gather_qmm(
           to_stream(s),
           group_size,
           bits,
+          mode,
           transpose,
           sorted_indices && !rhs_indices_,
           sorted_indices && !lhs_indices_),

@@ -1326,6 +1326,7 @@ array quantized_matmul(
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
+    const std::string& mode = "affine",
     StreamOrDevice s = {});

 /** Quantize a matrix along its last axis */
@@ -1333,6 +1334,7 @@ std::tuple<array, array, array> quantize(
     const array& w,
     int group_size = 64,
     int bits = 4,
+    const std::string& mode = "affine",
     StreamOrDevice s = {});

 /** Dequantize a matrix produced by quantize() */
@@ -1342,6 +1344,7 @@ array dequantize(
     const array& biases,
     int group_size = 64,
     int bits = 4,
+    const std::string& mode = "affine",
     StreamOrDevice s = {});

 /** Compute matrix products with matrix-level gather. */
@@ -1355,6 +1358,7 @@ array gather_qmm(
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
+    const std::string& mode = "affine",
     bool sorted_indices = false,
     StreamOrDevice s = {});

@@ -3234,6 +3234,7 @@ std::vector<array> QuantizedMatmul::vjp(
             !transpose_,
             group_size_,
             bits_,
+            mode_,
             stream()));
   }

@@ -3262,6 +3263,7 @@ std::vector<array> QuantizedMatmul::vjp(
           zeros_like(primals[3], stream()),
           group_size_,
           bits_,
+          mode_,
           stream());
       wq = unflatten(wq, -1, {-1, group_size_}, stream());
       vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));
@@ -3287,13 +3289,14 @@ std::vector<array> QuantizedMatmul::jvp(
       transpose_,
       group_size_,
       bits_,
+      mode_,
       stream())};
 }

 bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
   const QuantizedMatmul& qm_other = static_cast<const QuantizedMatmul&>(other);
   return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&
-      transpose_ == qm_other.transpose_;
+      mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_;
 }

 std::vector<Shape> QuantizedMatmul::output_shapes(
@@ -3348,6 +3351,7 @@ std::vector<array> GatherQMM::vjp(
         !transpose_,
         group_size_,
         bits_,
+        mode_,
         sorted,
         stream());
     if (sorted && no_broadcast) {
@@ -3406,6 +3410,7 @@ std::vector<array> GatherQMM::vjp(
                 zeros_like(biases, stream()),
                 group_size_,
                 bits_,
+                mode_,
                 stream()),
             -1,
             {-1, group_size_},
@@ -3430,7 +3435,7 @@ std::vector<array> GatherQMM::jvp(
 bool GatherQMM::is_equivalent(const Primitive& other) const {
   const GatherQMM& qm_other = static_cast<const GatherQMM&>(other);
   return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&
-      transpose_ == qm_other.transpose_;
+      mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_;
 }

 std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(

@@ -1597,10 +1597,12 @@ class QuantizedMatmul : public UnaryPrimitive {
       Stream stream,
       int group_size,
       int bits,
+      const std::string& mode,
       bool transpose)
       : UnaryPrimitive(stream),
         group_size_(group_size),
         bits_(bits),
+        mode_(mode),
         transpose_(transpose) {}

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
@@ -1612,12 +1614,13 @@ class QuantizedMatmul : public UnaryPrimitive {
   bool is_equivalent(const Primitive& other) const override;
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
   auto state() const {
-    return std::make_tuple(group_size_, bits_, transpose_);
+    return std::make_tuple(group_size_, bits_, mode_, transpose_);
   }

  private:
   int group_size_;
   int bits_;
+  std::string mode_;
   bool transpose_;
 };

@@ -1627,12 +1630,14 @@ class GatherQMM : public UnaryPrimitive {
       Stream stream,
       int group_size,
       int bits,
+      const std::string& mode,
       bool transpose,
       bool left_sorted = false,
       bool right_sorted = false)
       : UnaryPrimitive(stream),
         group_size_(group_size),
         bits_(bits),
+        mode_(mode),
         transpose_(transpose),
         left_sorted_(left_sorted),
         right_sorted_(right_sorted) {}
@@ -1646,12 +1651,13 @@ class GatherQMM : public UnaryPrimitive {
   bool is_equivalent(const Primitive& other) const override;
   auto state() const {
     return std::make_tuple(
-        group_size_, bits_, transpose_, left_sorted_, right_sorted_);
+        group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_);
   }

  private:
   int group_size_;
   int bits_;
+  std::string mode_;
   bool transpose_;
   bool left_sorted_;
   bool right_sorted_;

@@ -39,6 +39,6 @@ class Embedding(Module):
         """
         return x @ self.weight.T

-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"):
         """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
-        return QuantizedEmbedding.from_embedding(self, group_size, bits)
+        return QuantizedEmbedding.from_embedding(self, group_size, bits, mode)

@@ -70,9 +70,9 @@ class Linear(Module):
         x = x @ self["weight"].T
         return x

-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"):
         """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
-        return QuantizedLinear.from_linear(self, group_size, bits)
+        return QuantizedLinear.from_linear(self, group_size, bits, mode)


 class Bilinear(Module):

@@ -12,6 +12,8 @@ def quantize(
     model: Module,
     group_size: int = 64,
     bits: int = 4,
+    *,
+    mode: str = "affine",
     class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
 ):
     """Quantize the sub-modules of a module according to a predicate.
@@ -26,6 +28,8 @@ def quantize(
             :func:`mlx.core.quantize`). Default: ``64``.
         bits (int): The number of bits per parameter (see
             :func:`mlx.core.quantize`). Default: ``4``.
+        mode (str): The quantization method to use (see
+            :func:`mlx.core.quantize`). Default: ``"affine"``.
         class_predicate (Optional[Callable]): A callable which receives the
             :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a
             dict of params for `to_quantized` if it should be quantized and
@@ -39,7 +43,7 @@ def quantize(
         if bool_or_params := class_predicate(path, m):
             if hasattr(m, "to_quantized"):
                 if isinstance(bool_or_params, bool):
-                    return m.to_quantized(group_size=group_size, bits=bits)
+                    return m.to_quantized(group_size=group_size, bits=bits, mode=mode)
                 elif isinstance(bool_or_params, dict):
                     return m.to_quantized(**bool_or_params)
                 else:
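
As a usage sketch (not part of this diff), here is how the new keyword-only
``mode`` argument flows through the predicate logic above; the toy model is
invented for illustration:

    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 256))

    # mode is keyword-only and, when the predicate returns True, is forwarded
    # to each sub-module's to_quantized() along with group_size and bits.
    # A predicate returning a dict instead supplies to_quantized's kwargs
    # directly, bypassing the top-level group_size/bits/mode.
    nn.quantize(model, group_size=64, bits=4, mode="affine")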
@@ -72,6 +76,8 @@ class QuantizedEmbedding(Module):
             weight. See :func:`~mlx.core.quantize`. Default: ``64``.
         bits (int, optional): The bit width to use for the quantized weight.
             See :func:`~mlx.core.quantize`. Default: ``4``.
+        mode (str): The quantization method to use (see
+            :func:`mlx.core.quantize`). Default: ``"affine"``.
     """

     def __init__(
@@ -80,17 +86,21 @@ class QuantizedEmbedding(Module):
         dims: int,
         group_size: int = 64,
         bits: int = 4,
+        mode: str = "affine",
     ):
         super().__init__()

         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.mode = mode

         # Initialize the quantized weight
         scale = math.sqrt(1 / dims)
         weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, self.scales, self.biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )
         self.num_embeddings = num_embeddings
         self.dims = dims

@@ -104,6 +114,7 @@ class QuantizedEmbedding(Module):
             biases=self["biases"][x],
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )

     def as_linear(self, x):
@@ -121,23 +132,31 @@ class QuantizedEmbedding(Module):
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )

     def _extra_repr(self):
         return (
             f"{self.num_embeddings}, {self.dims}, "
-            f"group_size={self.group_size}, bits={self.bits}"
+            f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
         )

     @classmethod
     def from_embedding(
-        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+        cls,
+        embedding_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: str = "affine",
     ):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
         ql = cls(embedding_dims, dims, group_size, bits)
         ql.weight, ql.scales, ql.biases = mx.quantize(
-            embedding_layer.weight, group_size, bits
+            embedding_layer.weight,
+            group_size,
+            bits,
+            mode=mode,
         )
         return ql
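
A usage sketch of the updated classmethod (sizes invented for illustration):

    import mlx.core as mx
    import mlx.nn as nn

    emb = nn.Embedding(num_embeddings=1000, dims=256)
    # mode is forwarded into mx.quantize when converting the layer
    qemb = nn.QuantizedEmbedding.from_embedding(
        emb, group_size=64, bits=4, mode="affine"
    )

    tokens = mx.array([1, 42, 7])
    vecs = qemb(tokens)            # dequantizing table lookup
    logits = qemb.as_linear(vecs)  # same quantized weights as an output projection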
@@ -161,6 +180,8 @@ class QuantizedLinear(Module):
             weight. See :func:`~mlx.core.quantize`. Default: ``64``.
         bits (int, optional): The bit width to use for the quantized weight.
             See :func:`~mlx.core.quantize`. Default: ``4``.
+        mode (str): The quantization method to use (see
+            :func:`mlx.core.quantize`). Default: ``"affine"``.
     """

     def __init__(
@@ -170,12 +191,14 @@ class QuantizedLinear(Module):
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        mode: str = "affine",
     ):
         super().__init__()

         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.mode = mode

         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
@@ -184,7 +207,9 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, self.scales, self.biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )

         # And bias if needed
         if bias:
@@ -198,7 +223,7 @@ class QuantizedLinear(Module):
         in_dims *= 32 // self.bits
         return (
             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
-            f"group_size={self.group_size}, bits={self.bits}"
+            f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
         )

     def __call__(self, x):
@@ -210,18 +235,28 @@ class QuantizedLinear(Module):
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )
         if "bias" in self:
             x = x + self["bias"]
         return x

     @classmethod
-    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
+    def from_linear(
+        cls,
+        linear_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: str = "affine",
+    ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
         ql = cls(input_dims, output_dims, False, group_size, bits)
         ql.weight, ql.scales, ql.biases = mx.quantize(
-            linear_layer.weight, group_size, bits
+            linear_layer.weight,
+            group_size,
+            bits,
+            mode=mode,
         )
         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
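
And the analogous conversion for linear layers, again as a sketch with
invented shapes:

    import mlx.core as mx
    import mlx.nn as nn

    layer = nn.Linear(256, 512)
    qlayer = nn.QuantizedLinear.from_linear(layer, group_size=64, bits=4, mode="affine")

    x = mx.random.normal((8, 256))
    y = qlayer(x)  # calls mx.quantized_matmul with mode=self.mode internally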
@@ -4153,10 +4153,11 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the matrix multiplication with the quantized matrix ``w``. The
         quantization uses one floating point scale and bias per ``group_size`` of
@@ -4175,6 +4176,7 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
+          mode (str, optional): The quantization mode. Default: ``"affine"``.

         Returns:
           array: The result of the multiplication of ``x`` with ``w``.
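
An end-to-end sketch of the extended signature (values invented; note that
``mode`` sits before the keyword-only ``stream``):

    import mlx.core as mx

    x = mx.random.normal((2, 512))
    w = mx.random.normal((1024, 512))
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")

    # transpose=True computes x @ w_q.T, matching the (out, in) layout above
    y = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=True, group_size=64, bits=4, mode="affine"
    )
    print(y.shape)  # (2, 1024)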
@@ -4185,10 +4187,11 @@ void init_ops(nb::module_& m) {
       nb::arg(),
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
+          "def quantize(w: array, /, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
       R"pbdoc(
         Quantize the matrix ``w`` using ``bits`` bits per element.

@@ -4199,30 +4202,10 @@ void init_ops(nb::module_& m) {

         .. warning::

-          ``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32
+          ``quantize`` currently only supports 2D inputs with the second
+          dimension divisible by ``group_size``

-        Formally, for a group of :math:`g` consecutive elements :math:`w_1` to
-        :math:`w_g` in a row of ``w`` we compute the quantized representation
-        of each element :math:`\hat{w_i}` as follows
-
-        .. math::
-
-          \begin{aligned}
-            \alpha &= \max_i w_i \\
-            \beta &= \min_i w_i \\
-            s &= \frac{\alpha - \beta}{2^b - 1} \\
-            \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right).
-          \end{aligned}
-
-        After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits
-        and is packed in an unsigned 32-bit integer from the lower to upper
-        bits. For instance, for 4-bit quantization we fit 8 elements in an
-        unsigned 32 bit integer where the 1st element occupies the 4 least
-        significant bits, the 2nd bits 4-7 etc.
-
-        In order to be able to dequantize the elements of ``w`` we also need to
-        save :math:`s` and :math:`\beta` which are the returned ``scales`` and
-        ``biases`` respectively.
+        The supported quantization modes are described in more detail below.

         Args:
           w (array): Matrix to be quantized
@@ -4230,6 +4213,7 @@ void init_ops(nb::module_& m) {
             scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element of
             ``w`` in the returned quantized matrix. Default: ``4``.
+          mode (str, optional): The quantization mode. Default: ``"affine"``.

         Returns:
           tuple: A tuple containing
@@ -4237,6 +4221,31 @@ void init_ops(nb::module_& m) {
           * w_q (array): The quantized version of ``w``
           * scales (array): The scale to multiply each element with, namely :math:`s`
           * biases (array): The biases to add to each element, namely :math:`\beta`

+        Notes:
+          The currently supported quantization mode is `"affine"`.
+          Formally, for a group of :math:`g` consecutive elements :math:`w_1` to
+          :math:`w_g` in a row of ``w`` we compute the quantized representation
+          of each element :math:`\hat{w_i}` as follows
+
+          .. math::
+
+            \begin{aligned}
+              \alpha &= \max_i w_i \\
+              \beta &= \min_i w_i \\
+              s &= \frac{\alpha - \beta}{2^b - 1} \\
+              \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right).
+            \end{aligned}
+
+          After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits
+          and is packed in an unsigned 32-bit integer from the lower to upper
+          bits. For instance, for 4-bit quantization we fit 8 elements in an
+          unsigned 32 bit integer where the 1st element occupies the 4 least
+          significant bits, the 2nd bits 4-7 etc.
+
+          In order to be able to dequantize the elements of ``w`` we also need to
+          save :math:`s` and :math:`\beta` which are the returned ``scales`` and
+          ``biases`` respectively.
       )pbdoc");
   m.def(
       "dequantize",
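
To make the packing described in the notes concrete, a small sketch; the
shapes follow from group_size=64 and bits=4:

    import mlx.core as mx

    w = mx.random.normal((128, 128))
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")

    # Eight 4-bit values pack into each uint32, so the last dim shrinks by 32 // bits
    print(w_q.shape)     # (128, 16)
    print(scales.shape)  # (128, 2): one scale per 64-element group
    print(biases.shape)  # (128, 2): one bias per group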
@@ -4246,21 +4255,15 @@ void init_ops(nb::module_& m) {
       "biases"_a,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
-        Dequantize the matrix ``w`` using the provided ``scales`` and
-        ``biases`` and the ``group_size`` and ``bits`` configuration.
+        Dequantize the matrix ``w`` using quantization parameters.

-        Formally, given the notation in :func:`quantize`, we compute
-        :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
-        :math:`\beta` as follows
-
-        .. math::
-
-          w_i = s \hat{w_i} + \beta
+        The supported quantization modes are described in more detail below.

         Args:
           w (array): Matrix to be quantized
@@ -4270,9 +4273,20 @@ void init_ops(nb::module_& m) {
             scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
+          mode (str, optional): The quantization mode. Default: ``"affine"``.

         Returns:
           array: The dequantized version of ``w``

+        Notes:
+          The currently supported quantization mode is `"affine"`.
+          Formally, given the notation in :func:`quantize`, we compute
+          :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
+          :math:`\beta` as follows
+
+          .. math::
+
+            w_i = s \hat{w_i} + \beta
       )pbdoc");
   m.def(
       "gather_qmm",
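
A round-trip sketch of the affine relation w_i = s * w_hat_i + beta described
in the notes (the point is the bounded reconstruction error, not the exact
value printed):

    import mlx.core as mx

    w = mx.random.normal((128, 128))
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")

    # Groupwise w_hat = s * w_q + beta
    w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4, mode="affine")
    print(mx.abs(w - w_hat).max())  # small, on the order of the per-group step s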
@@ -4286,11 +4300,12 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = "affine",
       nb::kw_only(),
       "sorted_indices"_a = false,
       "stream"_a = nb::none(),
       nb::sig(
-          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform quantized matrix multiplication with matrix-level gather.

@@ -4316,6 +4331,7 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
+          mode (str, optional): The quantization mode. Default: ``"affine"``.
           sorted_indices (bool, optional): May allow a faster implementation
             if the passed indices are sorted. Default: ``False``.
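
Finally, a sketch of ``gather_qmm`` with per-row expert selection; the expert
count, shapes, and indices are invented, and each expert is quantized in 2D to
respect the ``quantize`` input constraint noted earlier:

    import mlx.core as mx

    # Stack four independently quantized (1024, 512) "expert" matrices
    experts = [mx.random.normal((1024, 512)) for _ in range(4)]
    parts = [mx.quantize(e, group_size=64, bits=4, mode="affine") for e in experts]
    w_q = mx.stack([p[0] for p in parts])
    scales = mx.stack([p[1] for p in parts])
    biases = mx.stack([p[2] for p in parts])

    x = mx.random.normal((8, 1, 512))
    rhs_indices = mx.array([0, 1, 2, 3, 0, 1, 2, 3])  # which expert each row uses

    y = mx.gather_qmm(
        x, w_q, scales, biases,
        rhs_indices=rhs_indices,
        transpose=True,
        group_size=64, bits=4, mode="affine",
    )
    print(y.shape)  # (8, 1, 1024)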