Working packed qmv

This commit is contained in:
Angelos Katharopoulos
2024-12-13 16:26:55 -08:00
parent 11ec07ff9d
commit 651c510940
5 changed files with 217 additions and 7 deletions

View File

@@ -4041,7 +4041,7 @@ void init_ops(nb::module_& m) {
nb::arg(),
nb::arg(),
"scales"_a,
"biases"_a,
"biases"_a = nb::none(),
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
@@ -4147,7 +4147,16 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"dequantize",
&mx::dequantize,
[](const mx::array& wq,
const mx::array& scales,
const std::optional<mx::array>& biases,
int group_size,
int bits,
const std::string& type,
mx::StreamOrDevice s) {
return mx::dequantize(
wq, scales, biases, group_size, bits, mx::from_string(type), s);
},
nb::arg(),
"scales"_a,
"biases"_a,