add mode parameter for quantization

Awni Hannun 2025-08-15 17:36:55 -07:00
parent e843c4d8d5
commit 8ec8d44ee6
9 changed files with 127 additions and 56 deletions

View File

@@ -127,7 +127,7 @@ relying on a copy from ``ensure_row_contiguous``:
    name="myexp_strided",
    input_names=["inp"],
    output_names=["out"],
-    source=source
+    source=source,
    ensure_row_contiguous=False,
)

View File

@@ -4029,6 +4029,7 @@ array quantized_matmul(
    bool transpose /* = true */,
    int group_size /* = 64 */,
    int bits /* = 4 */,
+    const std::string& mode /* = "affine" */,
    StreamOrDevice s /* = {} */) {
  // Check and extract the quantized matrix shape against x
  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
@@ -4056,7 +4057,7 @@ array quantized_matmul(
      std::move(out_shape),
      dtype,
      std::make_shared<QuantizedMatmul>(
-          to_stream(s), group_size, bits, transpose),
+          to_stream(s), group_size, bits, mode, transpose),
      std::move(inputs));
}
@@ -4064,6 +4065,7 @@ std::tuple<array, array, array> quantize(
    const array& w,
    int group_size /* = 64 */,
    int bits /* = 4 */,
+    const std::string& mode /* = "affine" */,
    StreamOrDevice s /* = {} */) {
  return fast::affine_quantize(w, group_size, bits, s);
}
@@ -4074,6 +4076,7 @@ array dequantize(
    const array& biases,
    int group_size /* = 64 */,
    int bits /* = 4 */,
+    const std::string& mode /* = "affine" */,
    StreamOrDevice s /* = {} */) {
  return fast::affine_dequantize(w, scales, biases, group_size, bits, s);
}
@@ -4088,11 +4091,12 @@ array gather_qmm(
    bool transpose /* = true */,
    int group_size /* = 64 */,
    int bits /* = 4 */,
+    const std::string& mode /* = "affine" */,
    bool sorted_indices /* = false */,
    StreamOrDevice s /* = {} */) {
  if (!lhs_indices_ && !rhs_indices_) {
    return quantized_matmul(
-        x, w, scales, biases, transpose, group_size, bits, s);
+        x, w, scales, biases, transpose, group_size, bits, mode, s);
  }
  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
@@ -4132,6 +4136,7 @@ array gather_qmm(
      to_stream(s),
      group_size,
      bits,
+      mode,
      transpose,
      sorted_indices && !rhs_indices_,
      sorted_indices && !lhs_indices_),

View File

@@ -1326,6 +1326,7 @@ array quantized_matmul(
    bool transpose = true,
    int group_size = 64,
    int bits = 4,
+    const std::string& mode = "affine",
    StreamOrDevice s = {});

/** Quantize a matrix along its last axis */
@@ -1333,6 +1334,7 @@ std::tuple<array, array, array> quantize(
    const array& w,
    int group_size = 64,
    int bits = 4,
+    const std::string& mode = "affine",
    StreamOrDevice s = {});

/** Dequantize a matrix produced by quantize() */
@@ -1342,6 +1344,7 @@ array dequantize(
    const array& biases,
    int group_size = 64,
    int bits = 4,
+    const std::string& mode = "affine",
    StreamOrDevice s = {});

/** Compute matrix products with matrix-level gather. */
@@ -1355,6 +1358,7 @@ array gather_qmm(
    bool transpose = true,
    int group_size = 64,
    int bits = 4,
+    const std::string& mode = "affine",
    bool sorted_indices = false,
    StreamOrDevice s = {});

View File

@@ -3234,6 +3234,7 @@ std::vector<array> QuantizedMatmul::vjp(
        !transpose_,
        group_size_,
        bits_,
+        mode_,
        stream()));
  }
@@ -3262,6 +3263,7 @@ std::vector<array> QuantizedMatmul::vjp(
        zeros_like(primals[3], stream()),
        group_size_,
        bits_,
+        mode_,
        stream());
    wq = unflatten(wq, -1, {-1, group_size_}, stream());
    vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream()));
@@ -3287,13 +3289,14 @@ std::vector<array> QuantizedMatmul::jvp(
      transpose_,
      group_size_,
      bits_,
+      mode_,
      stream())};
}

bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
  const QuantizedMatmul& qm_other = static_cast<const QuantizedMatmul&>(other);
  return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&
-      transpose_ == qm_other.transpose_;
+      mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_;
}

std::vector<Shape> QuantizedMatmul::output_shapes(
@@ -3348,6 +3351,7 @@ std::vector<array> GatherQMM::vjp(
        !transpose_,
        group_size_,
        bits_,
+        mode_,
        sorted,
        stream());
    if (sorted && no_broadcast) {
@@ -3406,6 +3410,7 @@ std::vector<array> GatherQMM::vjp(
            zeros_like(biases, stream()),
            group_size_,
            bits_,
+            mode_,
            stream()),
        -1,
        {-1, group_size_},
@@ -3430,7 +3435,7 @@ std::vector<array> GatherQMM::jvp(
bool GatherQMM::is_equivalent(const Primitive& other) const {
  const GatherQMM& qm_other = static_cast<const GatherQMM&>(other);
  return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&
-      transpose_ == qm_other.transpose_;
+      mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_;
}

std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(

View File

@@ -1597,10 +1597,12 @@ class QuantizedMatmul : public UnaryPrimitive {
      Stream stream,
      int group_size,
      int bits,
+      const std::string& mode,
      bool transpose)
      : UnaryPrimitive(stream),
        group_size_(group_size),
        bits_(bits),
+        mode_(mode),
        transpose_(transpose) {}

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
@@ -1612,12 +1614,13 @@ class QuantizedMatmul : public UnaryPrimitive {
  bool is_equivalent(const Primitive& other) const override;
  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
  auto state() const {
-    return std::make_tuple(group_size_, bits_, transpose_);
+    return std::make_tuple(group_size_, bits_, mode_, transpose_);
  }

 private:
  int group_size_;
  int bits_;
+  std::string mode_;
  bool transpose_;
};
@@ -1627,12 +1630,14 @@ class GatherQMM : public UnaryPrimitive {
      Stream stream,
      int group_size,
      int bits,
+      const std::string& mode,
      bool transpose,
      bool left_sorted = false,
      bool right_sorted = false)
      : UnaryPrimitive(stream),
        group_size_(group_size),
        bits_(bits),
+        mode_(mode),
        transpose_(transpose),
        left_sorted_(left_sorted),
        right_sorted_(right_sorted) {}
@@ -1646,12 +1651,13 @@ class GatherQMM : public UnaryPrimitive {
  bool is_equivalent(const Primitive& other) const override;
  auto state() const {
    return std::make_tuple(
-        group_size_, bits_, transpose_, left_sorted_, right_sorted_);
+        group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_);
  }

 private:
  int group_size_;
  int bits_;
+  std::string mode_;
  bool transpose_;
  bool left_sorted_;
  bool right_sorted_;

View File

@@ -39,6 +39,6 @@ class Embedding(Module):
        """
        return x @ self.weight.T

-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"):
        """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
-        return QuantizedEmbedding.from_embedding(self, group_size, bits)
+        return QuantizedEmbedding.from_embedding(self, group_size, bits, mode)

View File

@@ -70,9 +70,9 @@ class Linear(Module):
        x = x @ self["weight"].T
        return x

-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"):
        """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
-        return QuantizedLinear.from_linear(self, group_size, bits)
+        return QuantizedLinear.from_linear(self, group_size, bits, mode)


class Bilinear(Module):
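The ``mode`` argument now threads through both conversion helpers (``Embedding.to_quantized`` above and ``Linear.to_quantized`` here). A minimal usage sketch, assuming the defaults from this diff and an illustrative layer size::

    import mlx.core as mx
    import mlx.nn as nn

    layer = nn.Linear(256, 512)

    # Per this diff, to_quantized forwards mode along with
    # group_size and bits to QuantizedLinear.from_linear.
    qlayer = layer.to_quantized(group_size=64, bits=4, mode="affine")

    x = mx.random.normal(shape=(2, 256))
    y = qlayer(x)  # approximates layer(x) with 4-bit affine-quantized weights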

View File

@@ -12,6 +12,8 @@ def quantize(
    model: Module,
    group_size: int = 64,
    bits: int = 4,
+    *,
+    mode: str = "affine",
    class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
):
    """Quantize the sub-modules of a module according to a predicate.
@@ -26,6 +28,8 @@ def quantize(
            :func:`mlx.core.quantize`). Default: ``64``.
        bits (int): The number of bits per parameter (see
            :func:`mlx.core.quantize`). Default: ``4``.
+        mode (str): The quantization method to use (see
+            :func:`mlx.core.quantize`). Default: ``"affine"``.
        class_predicate (Optional[Callable]): A callable which receives the
            :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a
            dict of params for `to_quantized` if it should be quantized and
@@ -39,7 +43,7 @@ def quantize(
        if bool_or_params := class_predicate(path, m):
            if hasattr(m, "to_quantized"):
                if isinstance(bool_or_params, bool):
-                    return m.to_quantized(group_size=group_size, bits=bits)
+                    return m.to_quantized(group_size=group_size, bits=bits, mode=mode)
                elif isinstance(bool_or_params, dict):
                    return m.to_quantized(**bool_or_params)
                else:
@@ -72,6 +76,8 @@ class QuantizedEmbedding(Module):
            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
        bits (int, optional): The bit width to use for the quantized weight.
            See :func:`~mlx.core.quantize`. Default: ``4``.
+        mode (str): The quantization method to use (see
+            :func:`mlx.core.quantize`). Default: ``"affine"``.
    """

    def __init__(
@@ -80,17 +86,21 @@ class QuantizedEmbedding(Module):
        dims: int,
        group_size: int = 64,
        bits: int = 4,
+        mode: str = "affine",
    ):
        super().__init__()

        # Quantization config
        self.group_size = group_size
        self.bits = bits
+        self.mode = mode

        # Initialize the quantized weight
        scale = math.sqrt(1 / dims)
        weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, self.scales, self.biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )
        self.num_embeddings = num_embeddings
        self.dims = dims
@@ -104,6 +114,7 @@ class QuantizedEmbedding(Module):
            biases=self["biases"][x],
            group_size=self.group_size,
            bits=self.bits,
+            mode=self.mode,
        )

    def as_linear(self, x):
@@ -121,23 +132,31 @@ class QuantizedEmbedding(Module):
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
+            mode=self.mode,
        )

    def _extra_repr(self):
        return (
            f"{self.num_embeddings}, {self.dims}, "
-            f"group_size={self.group_size}, bits={self.bits}"
+            f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
        )

    @classmethod
    def from_embedding(
-        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+        cls,
+        embedding_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: str = "affine",
    ):
        """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
        embedding_dims, dims = embedding_layer.weight.shape
        ql = cls(embedding_dims, dims, group_size, bits)
        ql.weight, ql.scales, ql.biases = mx.quantize(
-            embedding_layer.weight, group_size, bits
+            embedding_layer.weight,
+            group_size,
+            bits,
+            mode=mode,
        )
        return ql
@@ -161,6 +180,8 @@ class QuantizedLinear(Module):
            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
        bits (int, optional): The bit width to use for the quantized weight.
            See :func:`~mlx.core.quantize`. Default: ``4``.
+        mode (str): The quantization method to use (see
+            :func:`mlx.core.quantize`). Default: ``"affine"``.
    """

    def __init__(
@@ -170,12 +191,14 @@ class QuantizedLinear(Module):
        bias: bool = True,
        group_size: int = 64,
        bits: int = 4,
+        mode: str = "affine",
    ):
        super().__init__()

        # Quantization config
        self.group_size = group_size
        self.bits = bits
+        self.mode = mode

        # Initialize the quantized weight
        scale = math.sqrt(1 / input_dims)
@@ -184,7 +207,9 @@ class QuantizedLinear(Module):
            high=scale,
            shape=(output_dims, input_dims),
        )
-        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+        self.weight, self.scales, self.biases = mx.quantize(
+            weight, group_size, bits, mode=mode
+        )

        # And bias if needed
        if bias:
@@ -198,7 +223,7 @@ class QuantizedLinear(Module):
        in_dims *= 32 // self.bits
        return (
            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
-            f"group_size={self.group_size}, bits={self.bits}"
+            f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}"
        )

    def __call__(self, x):
@@ -210,18 +235,28 @@ class QuantizedLinear(Module):
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
+            mode=self.mode,
        )
        if "bias" in self:
            x = x + self["bias"]
        return x

    @classmethod
-    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
+    def from_linear(
+        cls,
+        linear_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: str = "affine",
+    ):
        """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
        output_dims, input_dims = linear_layer.weight.shape
        ql = cls(input_dims, output_dims, False, group_size, bits)
        ql.weight, ql.scales, ql.biases = mx.quantize(
-            linear_layer.weight, group_size, bits
+            linear_layer.weight,
+            group_size,
+            bits,
+            mode=mode,
        )
        if "bias" in linear_layer:
            ql.bias = linear_layer.bias
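Taken together, ``nn.quantize`` now accepts a keyword-only ``mode`` and passes it down to each sub-module's ``to_quantized``. A hedged sketch with a made-up model; note that a ``class_predicate`` returning a dict of params must include ``mode`` itself if it wants a non-default value::

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))

    # mode is keyword-only; layers matched by the predicate are replaced
    # in place with their quantized counterparts.
    nn.quantize(
        model,
        group_size=64,
        bits=4,
        mode="affine",
        class_predicate=lambda path, m: isinstance(m, nn.Linear),
    )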

View File

@@ -4153,10 +4153,11 @@ void init_ops(nb::module_& m) {
      "transpose"_a = true,
      "group_size"_a = 64,
      "bits"_a = 4,
+      "mode"_a = "affine",
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Perform the matrix multiplication with the quantized matrix ``w``. The
        quantization uses one floating point scale and bias per ``group_size`` of
@@ -4175,6 +4176,7 @@ void init_ops(nb::module_& m) {
            shares a scale and bias. Default: ``64``.
          bits (int, optional): The number of bits occupied by each element in
            ``w``. Default: ``4``.
+          mode (str, optional): The quantization mode. Default: ``"affine"``.

        Returns:
          array: The result of the multiplication of ``x`` with ``w``.
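A small end-to-end sketch of the binding above, with illustrative shapes and the default ``transpose=True`` layout (``w`` stored as ``(N, K)``)::

    import mlx.core as mx

    x = mx.random.normal(shape=(8, 256))
    w = mx.random.normal(shape=(512, 256))

    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")
    y = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=True, group_size=64, bits=4, mode="affine"
    )
    # y has shape (8, 512) and approximates x @ w.T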
@@ -4185,10 +4187,11 @@ void init_ops(nb::module_& m) {
      nb::arg(),
      "group_size"_a = 64,
      "bits"_a = 4,
+      "mode"_a = "affine",
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
+          "def quantize(w: array, /, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
      R"pbdoc(
        Quantize the matrix ``w`` using ``bits`` bits per element.
@@ -4199,8 +4202,28 @@ void init_ops(nb::module_& m) {
        .. warning::

-          ``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32
+          ``quantize`` currently only supports 2D inputs with the second
+          dimension divisible by ``group_size``

+        The supported quantization modes are described in more detail below.
+
+        Args:
+          w (array): Matrix to be quantized
+          group_size (int, optional): The size of the group in ``w`` that shares a
+            scale and bias. Default: ``64``.
+          bits (int, optional): The number of bits occupied by each element of
+            ``w`` in the returned quantized matrix. Default: ``4``.
+          mode (str, optional): The quantization mode. Default: ``"affine"``.
+
+        Returns:
+          tuple: A tuple containing
+
+          * w_q (array): The quantized version of ``w``
+          * scales (array): The scale to multiply each element with, namely :math:`s`
+          * biases (array): The biases to add to each element, namely :math:`\beta`
+
+        Notes:
+          The currently supported quantization mode is `"affine"`.

          Formally, for a group of :math:`g` consecutive elements :math:`w_1` to
          :math:`w_g` in a row of ``w`` we compute the quantized representation
          of each element :math:`\hat{w_i}` as follows
@@ -4223,20 +4246,6 @@ void init_ops(nb::module_& m) {
          In order to be able to dequantize the elements of ``w`` we also need to
          save :math:`s` and :math:`\beta` which are the returned ``scales`` and
          ``biases`` respectively.
-
-        Args:
-          w (array): Matrix to be quantized
-          group_size (int, optional): The size of the group in ``w`` that shares a
-            scale and bias. Default: ``64``.
-          bits (int, optional): The number of bits occupied by each element of
-            ``w`` in the returned quantized matrix. Default: ``4``.
-
-        Returns:
-          tuple: A tuple containing
-
-          * w_q (array): The quantized version of ``w``
-          * scales (array): The scale to multiply each element with, namely :math:`s`
-          * biases (array): The biases to add to each element, namely :math:`\beta`
      )pbdoc");
  m.def(
      "dequantize",
@@ -4246,21 +4255,15 @@ void init_ops(nb::module_& m) {
      "biases"_a,
      "group_size"_a = 64,
      "bits"_a = 4,
+      "mode"_a = "affine",
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
-        Dequantize the matrix ``w`` using the provided ``scales`` and
-        ``biases`` and the ``group_size`` and ``bits`` configuration.
+        Dequantize the matrix ``w`` using quantization parameters.

-        Formally, given the notation in :func:`quantize`, we compute
-        :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
-        :math:`\beta` as follows
-
-        .. math::
-
-          w_i = s \hat{w_i} + \beta
+        The supported quantization modes are described in more detail below.

        Args:
          w (array): Matrix to be quantized
@@ -4270,9 +4273,20 @@ void init_ops(nb::module_& m) {
            scale and bias. Default: ``64``.
          bits (int, optional): The number of bits occupied by each element in
            ``w``. Default: ``4``.
+          mode (str, optional): The quantization mode. Default: ``"affine"``.

        Returns:
          array: The dequantized version of ``w``
+
+        Notes:
+          The currently supported quantization mode is `"affine"`.
+
+          Formally, given the notation in :func:`quantize`, we compute
+          :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
+          :math:`\beta` as follows
+
+          .. math::
+
+            w_i = s \hat{w_i} + \beta
      )pbdoc");
  m.def(
      "gather_qmm",
@@ -4286,11 +4300,12 @@ void init_ops(nb::module_& m) {
      "transpose"_a = true,
      "group_size"_a = 64,
      "bits"_a = 4,
+      "mode"_a = "affine",
      nb::kw_only(),
      "sorted_indices"_a = false,
      "stream"_a = nb::none(),
      nb::sig(
-          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Perform quantized matrix multiplication with matrix-level gather.
@@ -4316,6 +4331,7 @@ void init_ops(nb::module_& m) {
            shares a scale and bias. Default: ``64``.
          bits (int, optional): The number of bits occupied by each element in
            ``w``. Default: ``4``.
+          mode (str, optional): The quantization mode. Default: ``"affine"``.
          sorted_indices (bool, optional): May allow a faster implementation
            if the passed indices are sorted. Default: ``False``.
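Finally, a hedged sketch of ``gather_qmm`` with the new ``mode`` argument. The expert-style setup (a stack of quantized matrices selected per row by ``rhs_indices``) and all shapes are illustrative::

    import mlx.core as mx

    # Four expert matrices, each (512, 256), quantized individually
    # since quantize takes 2D inputs.
    experts = [mx.random.normal(shape=(512, 256)) for _ in range(4)]
    parts = [mx.quantize(w, group_size=64, bits=4, mode="affine") for w in experts]
    w_q = mx.stack([p[0] for p in parts])
    scales = mx.stack([p[1] for p in parts])
    biases = mx.stack([p[2] for p in parts])

    x = mx.random.normal(shape=(2, 1, 256))
    rhs_indices = mx.array([3, 0])  # which expert each row of x multiplies

    y = mx.gather_qmm(
        x, w_q, scales, biases,
        rhs_indices=rhs_indices,
        transpose=True,
        group_size=64,
        bits=4,
        mode="affine",
    )
    # expected shape (2, 1, 512) for this setup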