mirror of https://github.com/ml-explore/mlx.git
Initial python binding
commit 11ec07ff9d (parent bdd68bd893)
@@ -158,6 +158,17 @@ std::ostream& operator<<(std::ostream& os, QuantizationType type) {
   return os << quantization_type;
 }
 
+QuantizationType from_string(const std::string& s) {
+  if (s == "affine") {
+    return QuantizationType::Affine;
+  }
+  if (s == "affine-packed") {
+    return QuantizationType::AffinePacked;
+  }
+
+  throw std::invalid_argument("Cannot map '" + s + "' to a quantization type");
+}
+
 namespace {
 
 inline size_t
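For reference, the new from_string helper accepts exactly the strings "affine" and "affine-packed" and throws std::invalid_argument for anything else, which surfaces as a Python exception through the bindings. A minimal Python sketch of the same mapping (illustrative only, not part of the patch; the real lookup is the C++ function in the hunk above):

from enum import Enum

class QuantizationType(Enum):
    # Mirrors the enum values named in the hunk above.
    AFFINE = "affine"
    AFFINE_PACKED = "affine-packed"

def from_string(s: str) -> QuantizationType:
    # Accept the two known names and reject everything else, as the C++ helper does.
    try:
        return QuantizationType(s)
    except ValueError:
        raise ValueError(f"Cannot map '{s}' to a quantization type") from None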

@@ -111,6 +111,7 @@ enum class QuantizationType {
 };
 
 std::ostream& operator<<(std::ostream& os, QuantizationType type);
+QuantizationType from_string(const std::string& s);
 
 namespace env {
 

@@ -4018,7 +4018,26 @@ void init_ops(nb::module_& m) {
       )pbdoc");
   m.def(
       "quantized_matmul",
-      &mx::quantized_matmul,
+      [](mx::array x,
+         mx::array w,
+         mx::array scales,
+         std::optional<mx::array> biases,
+         bool transpose,
+         int group_size,
+         int bits,
+         const std::string& type,
+         mx::StreamOrDevice s) {
+        return mx::quantized_matmul(
+            std::move(x),
+            std::move(w),
+            std::move(scales),
+            std::move(biases),
+            transpose,
+            group_size,
+            bits,
+            mx::from_string(type),
+            s);
+      },
       nb::arg(),
       nb::arg(),
       "scales"_a,
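A usage sketch for the rebound quantized_matmul (assuming this patch is built; the type keyword and the optional biases argument come from this commit and are not in upstream releases):

import mlx.core as mx

x = mx.random.normal((2, 512))
w = mx.random.normal((512, 512))

# Quantize with the default affine scheme; other types may return biases=None.
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, type="affine")

# The same group_size/bits/type used for quantization must be passed here.
y = mx.quantized_matmul(
    x, w_q, scales, biases,
    transpose=True, group_size=64, bits=4, type="affine",
)
print(y.shape)  # expected (2, 512): x @ w.T with transpose=True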

@@ -4026,10 +4045,11 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the matrix multiplication with the quantized matrix ``w``. The
         quantization uses one floating point scale and bias per ``group_size`` of

@@ -4040,7 +4060,8 @@ void init_ops(nb::module_& m) {
             x (array): Input array
             w (array): Quantized matrix packed in unsigned integers
             scales (array): The scales to use per ``group_size`` elements of ``w``
-            biases (array): The biases to use per ``group_size`` elements of ``w``
+            biases (array, optional): The biases to use per ``group_size``
+              elements of ``w`` depending on the quantization type
             transpose (bool, optional): Defines whether to multiply with the
               transposed ``w`` or not, namely whether we are performing
               ``x @ w.T`` or ``x @ w``. Default: ``True``.

@@ -4048,20 +4069,29 @@ void init_ops(nb::module_& m) {
               shares a scale and bias. Default: ``64``.
             bits (int, optional): The number of bits occupied by each element in
               ``w``. Default: ``4``.
+            type (str, optional): The type of quantization used for the matrix.
+              It can be 'affine' or 'affine-packed'.
 
         Returns:
             array: The result of the multiplication of ``x`` with ``w``.
       )pbdoc");
   m.def(
       "quantize",
-      &mx::quantize,
+      [](const mx::array& w,
+         int group_size,
+         int bits,
+         const std::string& type,
+         mx::StreamOrDevice s) {
+        return mx::quantize(w, group_size, bits, mx::from_string(type), s);
+      },
       nb::arg(),
       "group_size"_a = 64,
       "bits"_a = 4,
+      "type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
+          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
       R"pbdoc(
         Quantize the matrix ``w`` using ``bits`` bits per element.
 

@@ -4103,13 +4133,17 @@ void init_ops(nb::module_& m) {
               scale and bias. Default: ``64``.
             bits (int, optional): The number of bits occupied by each element of
               ``w`` in the returned quantized matrix. Default: ``4``.
+            type (str, optional): The type of quantization used for the matrix.
+              It can be 'affine' or 'affine-packed'.
 
         Returns:
             tuple: A tuple containing
 
             * w_q (array): The quantized version of ``w``
             * scales (array): The scale to multiply each element with, namely :math:`s`
-            * biases (array): The biases to add to each element, namely :math:`\beta`
+            * biases (array, optional): The biases to add to each element, namely
+              :math:`\beta`. Depending on the quantization type this return value
+              may be None.
       )pbdoc");
   m.def(
       "dequantize",
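Since the third element of the returned tuple is now optional, callers should be prepared for biases to be None. A hedged sketch (again assuming the patched build; which quantization types omit the biases is decided by the C++ implementation and is not shown in this diff):

import mlx.core as mx

w = mx.random.normal((1024, 1024))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, type="affine-packed")

# Per the updated docstring, biases may be None depending on the quantization type.
if biases is None:
    print("no separate biases for this quantization type")
else:
    print("biases shape:", biases.shape)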

@@ -4119,10 +4153,11 @@ void init_ops(nb::module_& m) {
       "biases"_a,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Dequantize the matrix ``w`` using the provided ``scales`` and
         ``biases`` and the ``group_size`` and ``bits`` configuration.

@@ -4143,6 +4178,8 @@ void init_ops(nb::module_& m) {
               scale and bias. Default: ``64``.
             bits (int, optional): The number of bits occupied by each element in
               ``w``. Default: ``4``.
+            type (str, optional): The type of quantization used for the matrix.
+              It can be 'affine' or 'affine-packed'.
 
         Returns:
             array: The dequantized version of ``w``
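A round-trip sketch for the updated dequantize binding (patched build assumed; upstream dequantize has no type keyword):

import mlx.core as mx

w = mx.random.normal((256, 256))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, type="affine")

# Reconstruct an approximation of w using the same group_size/bits/type.
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4, type="affine")
print(mx.abs(w - w_hat).max())  # quantization error: small but nonzero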

@@ -4159,10 +4196,11 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "type"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array], lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform quantized matrix multiplication with matrix-level gather.
 

@@ -4188,6 +4226,8 @@ void init_ops(nb::module_& m) {
               shares a scale and bias. Default: ``64``.
             bits (int, optional): The number of bits occupied by each element in
               ``w``. Default: ``4``.
+            type (str, optional): The type of quantization used for the matrix.
+              It can be 'affine' or 'affine-packed'.
 
         Returns:
             array: The result of the multiplication of ``x`` with ``w``
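Finally, a sketch for gather_qmm with the now-optional biases (patched build assumed; the shapes and index semantics below are illustrative assumptions: w stacks several quantized matrices and rhs_indices picks one of them per left-hand operand):

import mlx.core as mx

num_experts, out_dims, in_dims = 2, 32, 64
w = mx.random.normal((num_experts, out_dims, in_dims))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)  # type defaults to "affine"

# Four 1 x in_dims operands on the left, each paired with one expert on the right.
x = mx.random.normal((4, 1, in_dims))
lhs_indices = mx.array([0, 1, 2, 3])
rhs_indices = mx.array([0, 0, 1, 1])

y = mx.gather_qmm(
    x, w_q, scales, biases,
    lhs_indices=lhs_indices, rhs_indices=rhs_indices,
    transpose=True, group_size=64, bits=4,
)
print(y.shape)  # expected (4, 1, 32) under these assumptions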