diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 613cc5a1e..53729bbf9 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -4016,6 +4016,35 @@ array conv_general(
       {in, wt});
 }
 
+std::pair<int, int> quantization_params_from_mode(
+    QuantizationMode mode,
+    std::optional<int> group_size_,
+    std::optional<int> bits_) {
+  int default_group_size;
+  int default_bits;
+  switch (mode) {
+    case QuantizationMode::Affine:
+      default_group_size = 64;
+      default_bits = 4;
+      break;
+    case QuantizationMode::Nvfp4:
+      default_group_size = 16;
+      default_bits = 4;
+      break;
+    case QuantizationMode::Mxfp4:
+      default_group_size = 32;
+      default_bits = 4;
+      break;
+    case QuantizationMode::Mxfp8:
+      default_group_size = 32;
+      default_bits = 8;
+      break;
+  }
+  return {
+      group_size_.has_value() ? *group_size_ : default_group_size,
+      bits_.has_value() ? *bits_ : default_bits};
+}
+
 std::pair<Dtype, QuantizationMode> validate_mode_with_type(
     std::string_view tag,
     const array& scales,
@@ -4023,7 +4052,6 @@ std::pair<Dtype, QuantizationMode> validate_mode_with_type(
     const std::optional<Dtype> out_type,
     const std::string& mode) {
   auto qmode = string_to_quantization_mode(mode, tag);
-  // TODO add tests for out_type
   if (out_type.has_value() && !issubdtype(*out_type, floating)) {
     std::ostringstream msg;
     msg << "[" << tag << "] Only real floating types are supported but "
@@ -4070,16 +4098,19 @@ array quantized_matmul(
     array scales,
     std::optional<array> biases /* = std::nullopt */,
     bool transpose /* = true */,
-    int group_size /* = 64 */,
-    int bits /* = 4 */,
+    std::optional<int> group_size_ /* = std::nullopt */,
+    std::optional<int> bits_ /* = std::nullopt */,
     const std::string& mode /* = "affine" */,
     StreamOrDevice s /* = {} */) {
+  auto [dtype, qmode] = validate_mode_with_type(
+      "quantized_matmul", scales, biases, std::nullopt, mode);
+
+  auto [group_size, bits] =
+      quantization_params_from_mode(qmode, group_size_, bits_);
   // Check and extract the quantized matrix shape against x
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
 
-  auto [dtype, qmode] = validate_mode_with_type(
-      "quantized_matmul", scales, biases, std::nullopt, mode);
   dtype = promote_types(x.dtype(), dtype);
 
   if (!issubdtype(dtype, floating)) {
@@ -4317,11 +4348,13 @@ std::vector<array> fp_quantize(
 
 std::vector<array> quantize(
     const array& w,
-    int group_size /* = 64 */,
-    int bits /* = 4 */,
+    std::optional<int> group_size_ /* = std::nullopt */,
+    std::optional<int> bits_ /* = std::nullopt */,
     const std::string& mode /* = "affine" */,
     StreamOrDevice s /* = {} */) {
   auto qmode = string_to_quantization_mode(mode, "quantize");
+  auto [group_size, bits] =
+      quantization_params_from_mode(qmode, group_size_, bits_);
   if (!issubdtype(w.dtype(), floating)) {
     std::ostringstream msg;
     msg << "[quantize] Only real floating types can be quantized "
@@ -4563,13 +4596,15 @@ array dequantize(
     const array& w,
     const array& scales,
     const std::optional<array>& biases /* = std::nullopt */,
-    int group_size /* = 64 */,
-    int bits /* = 4 */,
+    std::optional<int> group_size_ /* = std::nullopt */,
+    std::optional<int> bits_ /* = std::nullopt */,
     const std::string& mode /* = "affine" */,
     std::optional<Dtype> dtype /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto [out_type, qmode] =
       validate_mode_with_type("dequantize", scales, biases, dtype, mode);
+  auto [group_size, bits] =
+      quantization_params_from_mode(qmode, group_size_, bits_);
   if (bits <= 0) {
     std::ostringstream msg;
     msg << "[dequantize] Invalid value for bits: " << bits;
@@ -4644,21 +4679,22 @@ array gather_qmm(
     std::optional<array> lhs_indices_ /* = std::nullopt */,
     std::optional<array> rhs_indices_ /* = std::nullopt */,
     bool transpose /* = true */,
-    int group_size /* = 64 */,
-    int bits /* = 4 */,
+    std::optional<int> group_size_ /* = std::nullopt */,
+    std::optional<int> bits_ /* = std::nullopt */,
     const std::string& mode /* = "affine" */,
     bool sorted_indices /* = false */,
     StreamOrDevice s /* = {} */) {
   if (!lhs_indices_ && !rhs_indices_) {
     return quantized_matmul(
-        x, w, scales, biases, transpose, group_size, bits, mode, s);
+        x, w, scales, biases, transpose, group_size_, bits_, mode, s);
   }
 
-  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
-
   auto [out_type, qmode] =
       validate_mode_with_type("gather_qmm", scales, biases, std::nullopt, mode);
+  auto [group_size, bits] =
+      quantization_params_from_mode(qmode, group_size_, bits_);
+  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
+      "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
 
   out_type = promote_types(x.dtype(), out_type);
 
   if (!issubdtype(out_type, floating)) {
diff --git a/mlx/ops.h b/mlx/ops.h
index b86df59fa..49c64e74f 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1379,16 +1379,16 @@ array quantized_matmul(
     array scales,
     std::optional<array> biases = std::nullopt,
     bool transpose = true,
-    int group_size = 64,
-    int bits = 4,
+    std::optional<int> group_size = std::nullopt,
+    std::optional<int> bits = std::nullopt,
     const std::string& mode = "affine",
     StreamOrDevice s = {});
 
 /** Quantize a matrix along its last axis */
 std::vector<array> quantize(
     const array& w,
-    int group_size = 64,
-    int bits = 4,
+    std::optional<int> group_size = std::nullopt,
+    std::optional<int> bits = std::nullopt,
     const std::string& mode = "affine",
     StreamOrDevice s = {});
 
@@ -1397,8 +1397,8 @@ array dequantize(
     const array& w,
     const array& scales,
     const std::optional<array>& biases = std::nullopt,
-    int group_size = 64,
-    int bits = 4,
+    std::optional<int> group_size = std::nullopt,
+    std::optional<int> bits = std::nullopt,
     const std::string& mode = "affine",
     std::optional<Dtype> dtype = std::nullopt,
     StreamOrDevice s = {});
@@ -1418,8 +1418,8 @@ array gather_qmm(
     std::optional<array> lhs_indices = std::nullopt,
     std::optional<array> rhs_indices = std::nullopt,
     bool transpose = true,
-    int group_size = 64,
-    int bits = 4,
+    std::optional<int> group_size = std::nullopt,
+    std::optional<int> bits = std::nullopt,
     const std::string& mode = "affine",
     bool sorted_indices = false,
     StreamOrDevice s = {});
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 52c293e24..d5593b93b 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -4194,13 +4194,13 @@ void init_ops(nb::module_& m) {
       "scales"_a,
       "biases"_a = nb::none(),
       "transpose"_a = true,
-      "group_size"_a = 64,
-      "bits"_a = 4,
+      "group_size"_a = nb::none(),
+      "bits"_a = nb::none(),
       "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the matrix multiplication with the quantized matrix ``w``. The
        quantization uses one floating point scale and bias per ``group_size`` of
@@ -4217,9 +4217,11 @@ void init_ops(nb::module_& m) {
            transposed ``w`` or not, namely whether we are performing
            ``x @ w.T`` or ``x @ w``. Default: ``True``.
          group_size (int, optional): The size of the group in ``w`` that
-            shares a scale and bias. Default: ``64``.
+            shares a scale and bias. If unspecified, a default is chosen based
+            on the mode. Default: ``None``.
          bits (int, optional): The number of bits occupied by each element in
-            ``w``. Default: ``4``.
+            ``w``. If unspecified, a default is chosen based on the mode.
+            Default: ``None``.
          mode (str, optional): The quantization mode. Default: ``"affine"``.
 
        Returns:
@@ -4229,13 +4231,13 @@ void init_ops(nb::module_& m) {
       "quantize",
       &mx::quantize,
       nb::arg(),
-      "group_size"_a = 64,
-      "bits"_a = 4,
+      "group_size"_a = nb::none(),
+      "bits"_a = nb::none(),
       "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
+          "def quantize(w: array, /, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
       R"pbdoc(
         Quantize the matrix ``w`` using ``bits`` bits per element.
 
@@ -4268,6 +4270,16 @@ void init_ops(nb::module_& m) {
          * biases (array): The quantization biases (returned for ``mode=="affine"``).
 
        Notes:
+
+          ====== ===================== ========================== ============= ========
+          mode   group size            bits                       scale type    has bias
+          ====== ===================== ========================== ============= ========
+          affine 32, 64 (default), 128 2, 3, 4 (default), 5, 6, 8 same as input yes
+          mxfp4  32                    4                          e8m0          no
+          mxfp8  32                    8                          e8m0          no
+          nvfp4  16                    4                          e4m3          no
+          ====== ===================== ========================== ============= ========
+
          The ``"affine"`` mode quantizes groups of :math:`g` consecutive
          elements in a row of ``w``. For each group the quantized
          representation of each element :math:`\hat{w_i}` is computed as follows:
@@ -4310,14 +4322,14 @@ void init_ops(nb::module_& m) {
       nb::arg(),
       "scales"_a,
       "biases"_a = nb::none(),
-      "group_size"_a = 64,
-      "bits"_a = 4,
+      "group_size"_a = nb::none(),
+      "bits"_a = nb::none(),
       "mode"_a = "affine",
       "dtype"_a = nb::none(),
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', dtype: Optional[Dtype], *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', dtype: Optional[Dtype], *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Dequantize the matrix ``w`` using quantization parameters.
 
@@ -4361,14 +4373,14 @@ void init_ops(nb::module_& m) {
       "lhs_indices"_a = nb::none(),
       "rhs_indices"_a = nb::none(),
       "transpose"_a = true,
-      "group_size"_a = 64,
-      "bits"_a = 4,
+      "group_size"_a = nb::none(),
+      "bits"_a = nb::none(),
       "mode"_a = "affine",
       nb::kw_only(),
       "sorted_indices"_a = false,
       "stream"_a = nb::none(),
       nb::sig(
-          "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform quantized matrix multiplication with matrix-level gather.
 
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index ad8dffab6..1c1020c46 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -55,17 +55,17 @@ class TestQuantized(mlx_tests.MLXTestCase):
         # Invalid bits / group size
         with self.assertRaises(ValueError):
-            mx.quantize(w, bits=3, group_size=32, mode="mxfp4")
+            mx.quantize(w, bits=3, mode="mxfp4")
 
         with self.assertRaises(ValueError):
-            mx.quantize(w, group_size=64, bits=4, mode="mxfp4")
+            mx.quantize(w, group_size=64, mode="mxfp4")
 
-        w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
+        w_q, scales = mx.quantize(w, mode="mxfp4")
 
         with self.assertRaises(ValueError):
-            mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")
+            mx.dequantize(w_q, scales, bits=3, mode="mxfp4")
 
         with self.assertRaises(ValueError):
-            mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")
+            mx.dequantize(w_q, scales, group_size=64, mode="mxfp4")
 
         # Invalid output type
         with self.assertRaises(ValueError):
@@ -73,13 +73,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 w_q, scales, group_size=32, bits=4, mode="mxfp4", dtype=mx.int32
             )
 
-        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
+        w_hat = mx.dequantize(w_q, scales, mode="mxfp4")
         self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
 
         # test quantize/dequantize 0s
         a = mx.zeros((256, 512))
-        w_q, scales = mx.quantize(a, group_size=32, bits=4, mode="mxfp4")
-        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
+        w_q, scales = mx.quantize(a, mode="mxfp4")
+        w_hat = mx.dequantize(w_q, scales, mode="mxfp4")
         self.assertTrue(mx.all(w_hat == 0))
 
     def test_mxfp8_quantize_dequantize(self):
@@ -88,26 +88,26 @@ class TestQuantized(mlx_tests.MLXTestCase):
 
         # Invalid bits / group size
         with self.assertRaises(ValueError):
-            mx.quantize(w, bits=3, group_size=32, mode="mxfp8")
+            mx.quantize(w, bits=3, mode="mxfp8")
 
         with self.assertRaises(ValueError):
             mx.quantize(w, group_size=32, bits=7, mode="mxfp8")
 
-        w_q, scales = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")
+        w_q, scales = mx.quantize(w, group_size=32, mode="mxfp8")
 
         with self.assertRaises(ValueError):
-            mx.dequantize(w_q, scales, bits=8, group_size=16, mode="mxfp8")
+            mx.dequantize(w_q, scales, group_size=16, mode="mxfp8")
 
         with self.assertRaises(ValueError):
-            mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp8")
mode="mxfp8") - w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8") + w_hat = mx.dequantize(w_q, scales, mode="mxfp8") self.assertTrue(mx.allclose(w, w_hat, rtol=1e-1, atol=1e-1)) # test quantize/dequantize 0s a = mx.zeros((256, 512)) - w_q, scales = mx.quantize(a, group_size=32, bits=8, mode="mxfp8") - w_hat = mx.dequantize(w_q, scales, group_size=32, bits=8, mode="mxfp8") + w_q, scales = mx.quantize(a, mode="mxfp8") + w_hat = mx.dequantize(w_q, scales, mode="mxfp8") self.assertTrue(mx.all(w_hat == 0)) def test_nvfp4_quantize_dequantize(self): @@ -138,26 +138,26 @@ class TestQuantized(mlx_tests.MLXTestCase): # Invalid bits / group size with self.assertRaises(ValueError): - mx.quantize(w, bits=3, group_size=16, mode="nvfp4") + mx.quantize(w, bits=3, mode="nvfp4") with self.assertRaises(ValueError): - mx.quantize(w, group_size=64, bits=4, mode="nvfp4") + mx.quantize(w, group_size=64, mode="nvfp4") - w_q, scales = mx.quantize(w, group_size=16, bits=4, mode="nvfp4") + w_q, scales = mx.quantize(w, mode="nvfp4") with self.assertRaises(ValueError): - mx.dequantize(w_q, scales, bits=4, group_size=32, mode="nvfp4") + mx.dequantize(w_q, scales, mode="nvfp4") with self.assertRaises(ValueError): - mx.dequantize(w_q, scales, group_size=32, bits=4, mode="nvfp4") + mx.dequantize(w_q, scales, group_size=32, mode="nvfp4") - w_hat = mx.dequantize(w_q, scales, group_size=16, bits=4, mode="nvfp4") + w_hat = mx.dequantize(w_q, scales, mode="nvfp4") self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5)) # test quantize/dequantize 0s a = mx.zeros((256, 512)) - w_q, scales = mx.quantize(a, group_size=16, bits=4, mode="nvfp4") - w_hat = mx.dequantize(w_q, scales, group_size=16, bits=4, mode="nvfp4") + w_q, scales = mx.quantize(a, mode="nvfp4") + w_hat = mx.dequantize(w_q, scales, mode="nvfp4") self.assertTrue(mx.all(w_hat == 0)) def test_qmm(self):