Initial python binding

This commit is contained in:
Angelos Katharopoulos 2024-12-12 11:29:38 -08:00
parent bdd68bd893
commit 11ec07ff9d
3 changed files with 60 additions and 8 deletions

View File

@ -158,6 +158,17 @@ std::ostream& operator<<(std::ostream& os, QuantizationType type) {
return os << quantization_type;
}
QuantizationType from_string(const std::string& s) {
if (s == "affine") {
return QuantizationType::Affine;
}
if (s == "affine-packed") {
return QuantizationType::AffinePacked;
}
throw std::invalid_argument("Cannot map '" + s + "' to a quantization type");
}
namespace {
inline size_t

View File

@ -111,6 +111,7 @@ enum class QuantizationType {
};
std::ostream& operator<<(std::ostream& os, QuantizationType type);
QuantizationType from_string(const std::string& s);
namespace env {

View File

@ -4018,7 +4018,26 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"quantized_matmul",
&mx::quantized_matmul,
[](mx::array x,
mx::array w,
mx::array scales,
std::optional<mx::array> biases,
bool transpose,
int group_size,
int bits,
const std::string& type,
mx::StreamOrDevice s) {
return mx::quantized_matmul(
std::move(x),
std::move(w),
std::move(scales),
std::move(biases),
transpose,
group_size,
bits,
mx::from_string(type),
s);
},
nb::arg(),
nb::arg(),
"scales"_a,
@ -4026,10 +4045,11 @@ void init_ops(nb::module_& m) {
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"type"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform the matrix multiplication with the quantized matrix ``w``. The
quantization uses one floating point scale and bias per ``group_size`` of
@ -4040,7 +4060,8 @@ void init_ops(nb::module_& m) {
x (array): Input array
w (array): Quantized matrix packed in unsigned integers
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
biases (array, optional): The biases to use per ``group_size``
elements of ``w`` depending on the quantization type
transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. Default: ``True``.
@ -4048,20 +4069,29 @@ void init_ops(nb::module_& m) {
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
type (str, optional): The type of quantization used for the matrix.
It can be 'affine' or 'affine-packed'.
Returns:
array: The result of the multiplication of ``x`` with ``w``.
)pbdoc");
m.def(
"quantize",
&mx::quantize,
[](const mx::array& w,
int group_size,
int bits,
const std::string& type,
mx::StreamOrDevice s) {
return mx::quantize(w, group_size, bits, mx::from_string(type), s);
},
nb::arg(),
"group_size"_a = 64,
"bits"_a = 4,
"type"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
R"pbdoc(
Quantize the matrix ``w`` using ``bits`` bits per element.
@ -4103,13 +4133,17 @@ void init_ops(nb::module_& m) {
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. Default: ``4``.
type (str, optional): The type of quantization used for the matrix.
It can be 'affine' or 'affine-packed'.
Returns:
tuple: A tuple containing
* w_q (array): The quantized version of ``w``
* scales (array): The scale to multiply each element with, namely :math:`s`
* biases (array): The biases to add to each element, namely :math:`\beta`
* biases (array, optional): The biases to add to each element, namely
* :math:`\beta`. Depending on the quantization type this return value
may be None.
)pbdoc");
m.def(
"dequantize",
@ -4119,10 +4153,11 @@ void init_ops(nb::module_& m) {
"biases"_a,
"group_size"_a = 64,
"bits"_a = 4,
"type"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Dequantize the matrix ``w`` using the provided ``scales`` and
``biases`` and the ``group_size`` and ``bits`` configuration.
@ -4143,6 +4178,8 @@ void init_ops(nb::module_& m) {
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
type (str, optional): The type of quantization used for the matrix.
It can be 'affine' or 'affine-packed'.
Returns:
array: The dequantized version of ``w``
@ -4159,10 +4196,11 @@ void init_ops(nb::module_& m) {
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"type"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array], lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather.
@ -4188,6 +4226,8 @@ void init_ops(nb::module_& m) {
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
type (str, optional): The type of quantization used for the matrix.
It can be 'affine' or 'affine-packed'.
Returns:
array: The result of the multiplication of ``x`` with ``w``