Add mode parameter for quantization (#2499)

* add mode parameter for quantization

* mxfp4 quantize/dequantize + start of optional biases

* mxfp4 works

* speedup

* cpu mxfp4

* fix

* fix test tol

* fix

* refactor

* add quant mode enum
This commit is contained in:
Awni Hannun
2025-08-28 06:45:26 -07:00
committed by GitHub
parent 7ef8a6f2d5
commit 70560b6bd5
28 changed files with 3635 additions and 757 deletions

View File

@@ -1322,26 +1322,29 @@ array quantized_matmul(
array x,
array w,
array scales,
array biases,
std::optional<array> biases = std::nullopt,
bool transpose = true,
int group_size = 64,
int bits = 4,
const std::string& mode = "affine",
StreamOrDevice s = {});
/** Quantize a matrix along its last axis */
std::tuple<array, array, array> quantize(
std::vector<array> quantize(
const array& w,
int group_size = 64,
int bits = 4,
const std::string& mode = "affine",
StreamOrDevice s = {});
/** Dequantize a matrix produced by quantize() */
array dequantize(
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases = std::nullopt,
int group_size = 64,
int bits = 4,
const std::string& mode = "affine",
StreamOrDevice s = {});
/** Compute matrix products with matrix-level gather. */
@@ -1349,12 +1352,13 @@ array gather_qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases = std::nullopt,
std::optional<array> lhs_indices = std::nullopt,
std::optional<array> rhs_indices = std::nullopt,
bool transpose = true,
int group_size = 64,
int bits = 4,
const std::string& mode = "affine",
bool sorted_indices = false,
StreamOrDevice s = {});