mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-27 03:11:16 +08:00
Initial python binding
This commit is contained in:
parent
bdd68bd893
commit
11ec07ff9d
@ -158,6 +158,17 @@ std::ostream& operator<<(std::ostream& os, QuantizationType type) {
|
||||
return os << quantization_type;
|
||||
}
|
||||
|
||||
QuantizationType from_string(const std::string& s) {
|
||||
if (s == "affine") {
|
||||
return QuantizationType::Affine;
|
||||
}
|
||||
if (s == "affine-packed") {
|
||||
return QuantizationType::AffinePacked;
|
||||
}
|
||||
|
||||
throw std::invalid_argument("Cannot map '" + s + "' to a quantization type");
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
inline size_t
|
||||
|
@ -111,6 +111,7 @@ enum class QuantizationType {
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, QuantizationType type);
|
||||
QuantizationType from_string(const std::string& s);
|
||||
|
||||
namespace env {
|
||||
|
||||
|
@ -4018,7 +4018,26 @@ void init_ops(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"quantized_matmul",
|
||||
&mx::quantized_matmul,
|
||||
[](mx::array x,
|
||||
mx::array w,
|
||||
mx::array scales,
|
||||
std::optional<mx::array> biases,
|
||||
bool transpose,
|
||||
int group_size,
|
||||
int bits,
|
||||
const std::string& type,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::quantized_matmul(
|
||||
std::move(x),
|
||||
std::move(w),
|
||||
std::move(scales),
|
||||
std::move(biases),
|
||||
transpose,
|
||||
group_size,
|
||||
bits,
|
||||
mx::from_string(type),
|
||||
s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"scales"_a,
|
||||
@ -4026,10 +4045,11 @@ void init_ops(nb::module_& m) {
|
||||
"transpose"_a = true,
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"type"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform the matrix multiplication with the quantized matrix ``w``. The
|
||||
quantization uses one floating point scale and bias per ``group_size`` of
|
||||
@ -4040,7 +4060,8 @@ void init_ops(nb::module_& m) {
|
||||
x (array): Input array
|
||||
w (array): Quantized matrix packed in unsigned integers
|
||||
scales (array): The scales to use per ``group_size`` elements of ``w``
|
||||
biases (array): The biases to use per ``group_size`` elements of ``w``
|
||||
biases (array, optional): The biases to use per ``group_size``
|
||||
elements of ``w`` depending on the quantization type
|
||||
transpose (bool, optional): Defines whether to multiply with the
|
||||
transposed ``w`` or not, namely whether we are performing
|
||||
``x @ w.T`` or ``x @ w``. Default: ``True``.
|
||||
@ -4048,20 +4069,29 @@ void init_ops(nb::module_& m) {
|
||||
shares a scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. Default: ``4``.
|
||||
type (str, optional): The type of quantization used for the matrix.
|
||||
It can be 'affine' or 'affine-packed'.
|
||||
|
||||
Returns:
|
||||
array: The result of the multiplication of ``x`` with ``w``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"quantize",
|
||||
&mx::quantize,
|
||||
[](const mx::array& w,
|
||||
int group_size,
|
||||
int bits,
|
||||
const std::string& type,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::quantize(w, group_size, bits, mx::from_string(type), s);
|
||||
},
|
||||
nb::arg(),
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"type"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
|
||||
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
|
||||
R"pbdoc(
|
||||
Quantize the matrix ``w`` using ``bits`` bits per element.
|
||||
|
||||
@ -4103,13 +4133,17 @@ void init_ops(nb::module_& m) {
|
||||
scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element of
|
||||
``w`` in the returned quantized matrix. Default: ``4``.
|
||||
type (str, optional): The type of quantization used for the matrix.
|
||||
It can be 'affine' or 'affine-packed'.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing
|
||||
|
||||
* w_q (array): The quantized version of ``w``
|
||||
* scales (array): The scale to multiply each element with, namely :math:`s`
|
||||
* biases (array): The biases to add to each element, namely :math:`\beta`
|
||||
* biases (array, optional): The biases to add to each element, namely
|
||||
* :math:`\beta`. Depending on the quantization type this return value
|
||||
may be None.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"dequantize",
|
||||
@ -4119,10 +4153,11 @@ void init_ops(nb::module_& m) {
|
||||
"biases"_a,
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"type"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Dequantize the matrix ``w`` using the provided ``scales`` and
|
||||
``biases`` and the ``group_size`` and ``bits`` configuration.
|
||||
@ -4143,6 +4178,8 @@ void init_ops(nb::module_& m) {
|
||||
scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. Default: ``4``.
|
||||
type (str, optional): The type of quantization used for the matrix.
|
||||
It can be 'affine' or 'affine-packed'.
|
||||
|
||||
Returns:
|
||||
array: The dequantized version of ``w``
|
||||
@ -4159,10 +4196,11 @@ void init_ops(nb::module_& m) {
|
||||
"transpose"_a = true,
|
||||
"group_size"_a = 64,
|
||||
"bits"_a = 4,
|
||||
"type"_a = "affine",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array], lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform quantized matrix multiplication with matrix-level gather.
|
||||
|
||||
@ -4188,6 +4226,8 @@ void init_ops(nb::module_& m) {
|
||||
shares a scale and bias. Default: ``64``.
|
||||
bits (int, optional): The number of bits occupied by each element in
|
||||
``w``. Default: ``4``.
|
||||
type (str, optional): The type of quantization used for the matrix.
|
||||
It can be 'affine' or 'affine-packed'.
|
||||
|
||||
Returns:
|
||||
array: The result of the multiplication of ``x`` with ``w``
|
||||
|
Loading…
Reference in New Issue
Block a user