From b3916cbf2beb9e18535d21d20012756ec671c055 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 20 Dec 2023 16:53:53 -0800 Subject: [PATCH] Improve names of quantization arguments (#235) * Change the default quantization group_size to 64 * Rename groups to group_size and width to bits --- mlx/backend/accelerate/quantized.cpp | 18 +++--- mlx/backend/common/quantized.cpp | 51 ++++++++--------- mlx/backend/metal/kernels/quantized.metal | 68 +++++++++++------------ mlx/backend/metal/quantized.cpp | 4 +- mlx/ops.cpp | 56 ++++++++++--------- mlx/ops.h | 12 ++-- mlx/primitives.cpp | 2 +- mlx/primitives.h | 8 +-- python/mlx/nn/layers/quantized.py | 42 +++++++------- python/src/ops.cpp | 65 +++++++++++----------- python/tests/test_quantized.py | 38 ++++++------- 11 files changed, 184 insertions(+), 180 deletions(-) diff --git a/mlx/backend/accelerate/quantized.cpp b/mlx/backend/accelerate/quantized.cpp index dc545343e..6246e1deb 100644 --- a/mlx/backend/accelerate/quantized.cpp +++ b/mlx/backend/accelerate/quantized.cpp @@ -19,12 +19,12 @@ void _qmm_t_4_64( int M, int N, int K) { - constexpr int width = 4; - constexpr int groups = 64; - constexpr int bitmask = (1 << width) - 1; - constexpr int pack_factor = 32 / width; - constexpr int packs_in_group = groups / pack_factor; - const int Kg = K / groups; + constexpr int bits = 4; + constexpr int group_size = 64; + constexpr int bitmask = (1 << bits) - 1; + constexpr int pack_factor = 32 / bits; + constexpr int packs_in_group = group_size / pack_factor; + const int Kg = K / group_size; const int Kw = K / pack_factor; for (int m = 0; m < M; m++) { @@ -35,7 +35,7 @@ void _qmm_t_4_64( for (int n = 0; n < N; n++) { const simd_float16* x_local = (simd_float16*)x; simd_float16 sum = 0; - for (int k = 0; k < K; k += groups) { + for (int k = 0; k < K; k += group_size) { float scale = *scales_local++; float bias = *biases_local++; @@ -46,7 +46,7 @@ void _qmm_t_4_64( uint32_t wii = *w_local++; for (int p = 0; p < 8; p++) { wi[e * 8 + p] = wii & bitmask; - wii >>= width; + wii >>= bits; } } simd_float16 wf = simd_float(wi); @@ -85,7 +85,7 @@ void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { throw std::runtime_error("x, scales and biases should be row contiguous."); } - if (x.dtype() == float32 && width_ == 4 && groups_ == 64) { + if (x.dtype() == float32 && bits_ == 4 && group_size_ == 64) { out.set_data(allocator::malloc_or_wait(out.nbytes())); int K = x.shape(-1); int M = x.size() / K; diff --git a/mlx/backend/common/quantized.cpp b/mlx/backend/common/quantized.cpp index 2120d881a..1a9b27953 100644 --- a/mlx/backend/common/quantized.cpp +++ b/mlx/backend/common/quantized.cpp @@ -8,7 +8,7 @@ namespace mlx::core { namespace { -template +template void _qmm_t( T* result, const T* x, @@ -18,10 +18,10 @@ void _qmm_t( int M, int N, int K) { - constexpr int bitmask = (1 << width) - 1; - constexpr int pack_factor = 32 / width; - constexpr int packs_in_group = groups / pack_factor; - const int Kg = K / groups; + constexpr int bitmask = (1 << bits) - 1; + constexpr int pack_factor = 32 / bits; + constexpr int packs_in_group = group_size / pack_factor; + const int Kg = K / group_size; const int Kw = K / pack_factor; for (int m = 0; m < M; m++) { @@ -32,7 +32,7 @@ void _qmm_t( for (int n = 0; n < N; n++) { const T* x_local = x; T sum = 0; - for (int k = 0; k < K; k += groups) { + for (int k = 0; k < K; k += group_size) { T scale = *scales_local++; T bias = *biases_local++; @@ -42,7 +42,7 @@ void _qmm_t( #pragma clang loop unroll(full) for 
(int p = 0; p < pack_factor; p++) { sum += (*x_local++) * (scale * static_cast(wi & bitmask) + bias); - wi >>= width; + wi >>= bits; } } } @@ -64,11 +64,11 @@ void _qmm_t_dispatch_typed( int M, int N, int K, - int width, - int groups) { - switch (width) { + int group_size, + int bits) { + switch (bits) { case 2: { - switch (groups) { + switch (group_size) { case 64: return _qmm_t(result, x, w, scales, biases, M, N, K); case 128: @@ -76,7 +76,7 @@ void _qmm_t_dispatch_typed( } } case 4: { - switch (groups) { + switch (group_size) { case 64: return _qmm_t(result, x, w, scales, biases, M, N, K); case 128: @@ -84,7 +84,7 @@ void _qmm_t_dispatch_typed( } } case 8: { - switch (groups) { + switch (group_size) { case 64: return _qmm_t(result, x, w, scales, biases, M, N, K); case 128: @@ -93,9 +93,10 @@ void _qmm_t_dispatch_typed( } } std::ostringstream msg; - msg << "Quantization type not supported. Provided bit width=" << width - << " and groups=" << groups << ". The supported options are width in " - << "{2, 4, 8} and groups in {64, 128}."; + msg << "Quantization type not supported. Provided bits=" << bits + << " and group_size=" << group_size + << ". The supported options are bits in " + << "{2, 4, 8} and group_size in {64, 128}."; throw std::invalid_argument(msg.str()); } @@ -105,8 +106,8 @@ void _qmm_t_dispatch( const array& w, const array& scales, const array& biases, - int width, - int groups) { + int bits, + int group_size) { int K = x.shape(-1); int M = x.size() / K; int N = w.shape(1); @@ -122,8 +123,8 @@ void _qmm_t_dispatch( M, N, K, - width, - groups); + bits, + group_size); break; case float16: _qmm_t_dispatch_typed( @@ -135,8 +136,8 @@ void _qmm_t_dispatch( M, N, K, - width, - groups); + bits, + group_size); break; case bfloat16: _qmm_t_dispatch_typed( @@ -148,8 +149,8 @@ void _qmm_t_dispatch( M, N, K, - width, - groups); + bits, + group_size); break; default: throw std::invalid_argument( @@ -177,7 +178,7 @@ void QuantizedMatmul::eval(const std::vector& inputs, array& out) { } out.set_data(allocator::malloc_or_wait(out.nbytes())); - _qmm_t_dispatch(out, x, w, scales, biases, width_, groups_); + _qmm_t_dispatch(out, x, w, scales, biases, group_size_, bits_); } } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index eb48c92f1..8a9e89450 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -14,7 +14,7 @@ using namespace metal; MLX_MTL_CONST int SIMD_SIZE = 32; -template +template [[kernel]] void qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -30,10 +30,10 @@ template (w_local & bitmask) + bias) * x_thread[k]; - w_local >>= width; + w_local >>= bits; } } @@ -104,7 +104,7 @@ template +template [[kernel]] void qmm_t( const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(1)]], @@ -126,10 +126,10 @@ template 0) ? (BK / groups) : 1; + constexpr int groups_per_block = (BK / group_size > 0) ? 
(BK / group_size) : 1; constexpr int groups_per_simd = BN / (WM * WN); constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN); @@ -145,7 +145,7 @@ template (wi & bitmask) + bias; - wi >>= width; + wi >>= bits; } } } @@ -231,9 +231,9 @@ template ( \ +#define instantiate_qmv(name, itype, group_size, bits) \ + template [[host_name("qmv_n_" #name "_gs_" #group_size "_b_" #bits)]] \ + [[kernel]] void qmv( \ const device uint32_t* w [[buffer(0)]], \ const device itype* scales [[buffer(1)]], \ const device itype* biases [[buffer(2)]], \ @@ -246,10 +246,10 @@ template ( \ +#define instantiate_qmm_t(name, itype, group_size, bits) \ + template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \ + [[kernel]] void qmm_t( \ const device itype* x [[buffer(0)]], \ const device uint32_t* w [[buffer(1)]], \ const device itype* scales [[buffer(2)]], \ @@ -274,10 +274,10 @@ instantiate_qmv_types( 64, 8) uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); -#define instantiate_qmm_t_types(groups, width) \ - instantiate_qmm_t(float32, float, groups, width) \ - instantiate_qmm_t(float16, half, groups, width) \ - instantiate_qmm_t(bfloat16, bfloat16_t, groups, width) +#define instantiate_qmm_t_types(group_size, bits) \ + instantiate_qmm_t(float32, float, group_size, bits) \ + instantiate_qmm_t(float16, half, group_size, bits) \ + instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits) instantiate_qmm_t_types(128, 2) instantiate_qmm_t_types(128, 4) diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 7d8225797..398bc8ed0 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -58,7 +58,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { if (B == 1) { std::ostringstream kname; kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_to_name(out) - << "_groups_" << groups_ << "_width_" << width_; + << "_gs_" << group_size_ << "_b_" << bits_; // Encode and dispatch kernel auto compute_encoder = d.get_command_encoder(s.index); @@ -87,7 +87,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { else { std::ostringstream kname; kname << "qmm_" << (w_transposed ? 
"t_" : "n_") << type_to_name(out) - << "_groups_" << groups_ << "_width_" << width_; + << "_gs_" << group_size_ << "_b_" << bits_; // Encode and dispatch kernel auto compute_encoder = d.get_command_encoder(s.index); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 8aae5596d..4a7f39321 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2583,8 +2583,8 @@ array quantized_matmul( const array& w, const array& scales, const array& biases, - int groups /* = 128 */, - int width /* = 4 */, + int group_size /* = 64 */, + int bits /* = 4 */, StreamOrDevice s /* = {} */) { auto x = in_x; @@ -2611,24 +2611,25 @@ array quantized_matmul( x = reshape(x, {-1, x_inner_dims}, s); } - int w_inner_dims = w.shape(0) * (32 / width); + int w_inner_dims = w.shape(0) * (32 / bits); if (w_inner_dims != x_inner_dims) { std::ostringstream msg; msg << "[quantized_matmul] Last dimension of first input with " << "shape (..., " << x_inner_dims << ") does not match the expanded first " << "dimension of the quantized matrix " << w_inner_dims - << ", computed from shape " << w.shape() << " with groups=" << groups - << " and width=" << width; + << ", computed from shape " << w.shape() + << " with group_size=" << group_size << " and bits=" << bits; throw std::invalid_argument(msg.str()); } - int n_groups = x_inner_dims / groups; + int n_groups = x_inner_dims / group_size; if (scales.shape(-1) != n_groups || biases.shape(-1) != n_groups) { std::ostringstream msg; msg << "[quantized_matmul] Scales and biases provided do not match the " - << "quantization arguments (groups=" << groups << ", width=" << width - << "). Expected shapes (" << w.shape(1) << ", " << x_inner_dims / groups + << "quantization arguments (group_size=" << group_size + << ", bits=" << bits << "). Expected shapes (" << w.shape(1) << ", " + << x_inner_dims / group_size << "), but got scales.shape=" << scales.shape() << " and biases.shape=" << biases.shape(); throw std::invalid_argument(msg.str()); @@ -2637,7 +2638,7 @@ array quantized_matmul( auto out = array( {x.shape(0), w.shape(1)}, x.dtype(), - std::make_unique(to_stream(s), groups, width), + std::make_unique(to_stream(s), group_size, bits), {x, w, scales, biases}); // If needed reshape x to the original batch shape @@ -2651,8 +2652,8 @@ array quantized_matmul( std::tuple quantize( const array& w, - int groups /* = 128 */, - int width /* = 4 */, + int group_size /* = 64 */, + int bits /* = 4 */, StreamOrDevice s /* = {} */) { if (w.ndim() != 2) { throw std::invalid_argument("[quantize] Only matrices supported for now"); @@ -2663,23 +2664,24 @@ std::tuple quantize( "[quantize] All dimensions should be divisible by 32 for now"); } - if ((w.shape(-1) % groups) != 0) { + if ((w.shape(-1) % group_size) != 0) { std::ostringstream msg; msg << "[quantize] The last dimension of the matrix needs to be divisible by " - << "the quantization group size " << groups - << ". However the provided matrix" - << " has shape " << w.shape(); + << "the quantization group size " << group_size + << ". 
However the provided " + << " matrix has shape " << w.shape(); throw std::invalid_argument(msg.str()); } // Compute some constants used for the quantization - int n_bins = (1 << width) - 1; // 2**width - 1 - int el_per_int = 32 / width; - array shifts = power(array(2, uint32), arange(0, 32, width, uint32, s), s); + int n_bins = (1 << bits) - 1; // 2**bits - 1 + int el_per_int = 32 / bits; + array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s); shifts = reshape(shifts, {1, 1, -1}, s); // Compute scales and biases - array packed_w = reshape(w, {w.shape(0), w.shape(1) / groups, groups}, s); + array packed_w = + reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s); array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); array delta = divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s); @@ -2700,8 +2702,8 @@ array dequantize( const array& w, const array& scales, const array& biases, - int groups /* = 128 */, - int width /* = 4 */, + int group_size /* = 64 */, + int bits /* = 4 */, StreamOrDevice s /* = {} */) { if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) { throw std::invalid_argument("[dequantize] Only matrices supported for now"); @@ -2723,22 +2725,22 @@ array dequantize( } // Compute some constants for the dequantization - int el_per_int = 32 / width; + int el_per_int = 32 / bits; - if (w.shape(1) * el_per_int != scales.shape(1) * groups) { + if (w.shape(1) * el_per_int != scales.shape(1) * group_size) { std::ostringstream msg; msg << "[dequantize] Shape of scales and biases does not match the matrix " << "given the quantization parameters. Provided matrix of shape " << w.shape() << " and scales/biases of shape " << scales.shape() - << " with groups=" << groups << " and width=" << width << "."; + << " with group_size=" << group_size << " and bits=" << bits << "."; throw std::invalid_argument(msg.str()); } // Extract the pieces from the passed quantized matrix std::vector parts; - for (int start = 0; start < 32; start += width) { + for (int start = 0; start < 32; start += bits) { // TODO: Implement bitwise operators for integral types - int shift_left = 32 - (start + width); + int shift_left = 32 - (start + bits); int shift_right = shift_left + start; array p = multiply(w, array(1 << shift_left, uint32), s); p = floor_divide(p, array(1 << shift_right, uint32), s); @@ -2748,7 +2750,7 @@ array dequantize( array w_full = concatenate(parts, -1, s); // Dequantize - w_full = reshape(w_full, {w.shape(0), -1, groups}, s); + w_full = reshape(w_full, {w.shape(0), -1, group_size}, s); w_full = multiply(w_full, expand_dims(scales, -1, s), s); w_full = add(w_full, expand_dims(biases, -1, s), s); w_full = reshape(w_full, {w.shape(0), -1}, s); diff --git a/mlx/ops.h b/mlx/ops.h index 0c2c2916a..fe59d4e49 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1037,15 +1037,15 @@ array quantized_matmul( const array& w, const array& scales, const array& biases, - int groups = 128, - int width = 4, + int group_size = 64, + int bits = 4, StreamOrDevice s = {}); /** Quantize a matrix along its last axis */ std::tuple quantize( const array& w, - int groups = 128, - int width = 4, + int group_size = 64, + int bits = 4, StreamOrDevice s = {}); /** Dequantize a matrix produced by quantize() */ @@ -1053,8 +1053,8 @@ array dequantize( const array& w, const array& scales, const array& biases, - int groups = 128, - int width = 4, + int group_size = 64, + int bits = 4, StreamOrDevice s = 
{}); } // namespace mlx::core diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index f67340921..3366e463c 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1718,7 +1718,7 @@ array QuantizedMatmul::jvp( bool QuantizedMatmul::is_equivalent(const Primitive& other) const { const QuantizedMatmul& qm_other = static_cast(other); - return groups_ == qm_other.groups_ && width_ == qm_other.width_; + return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_; } std::pair RandomBits::vmap( diff --git a/mlx/primitives.h b/mlx/primitives.h index fbbde2dd1..0cb98c9c7 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1112,8 +1112,8 @@ class Power : public Primitive { class QuantizedMatmul : public Primitive { public: - explicit QuantizedMatmul(Stream stream, int groups, int width) - : Primitive(stream), groups_(groups), width_(width){}; + explicit QuantizedMatmul(Stream stream, int group_size, int bits) + : Primitive(stream), group_size_(group_size), bits_(bits){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1127,8 +1127,8 @@ class QuantizedMatmul : public Primitive { bool is_equivalent(const Primitive& other) const override; private: - int groups_; - int width_; + int group_size_; + int bits_; void eval(const std::vector& inputs, array& out); }; diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index e311049d3..6d2891db7 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -26,12 +26,12 @@ class QuantizedLinear(Module): Args: input_dims (int): The dimensionality of the input features output_dims (int): The dimensionality of the output features - bias (bool): If set to ``False`` then the layer will not use a bias. - (default: True). - groups (int): The group size to use for the quantized weight. See - :func:`~mlx.core.quantize`. (default: 128) - width (int): The bit width to use for the quantized weight. See - :func:`~mlx.core.quantize`. (default: 4) + bias (bool, optional): If set to ``False`` then the layer will not use + a bias. (default: True). + group_size (int, optional): The group size to use for the quantized + weight. See :func:`~mlx.core.quantize`. (default: 64) + bits (int, optional): The bit width to use for the quantized weight. + See :func:`~mlx.core.quantize`. 
(default: 4) """ def __init__( @@ -39,14 +39,14 @@ class QuantizedLinear(Module): input_dims: int, output_dims: int, bias: bool = True, - groups: int = 64, - width: int = 4, + group_size: int = 64, + bits: int = 4, ): super().__init__() # Quantization config - self.groups = groups - self.width = width + self.group_size = group_size + self.bits = bits # Initialize the quantized weight scale = math.sqrt(1 / input_dims) @@ -55,7 +55,7 @@ class QuantizedLinear(Module): high=scale, shape=(output_dims, input_dims), ) - self.weight, self.scales, self.biases = mx.quantize(weight, groups, width) + self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits) # And bias if needed if bias: @@ -72,10 +72,10 @@ class QuantizedLinear(Module): def _extra_repr(self): out_dims, in_dims = self.weight.shape - in_dims *= 32 // self.width + in_dims *= 32 // self.bits return ( f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}," - f"groups={self.groups}, width={self.width}" + f"group_size={self.group_size}, bits={self.bits}" ) def __call__(self, x): @@ -84,21 +84,21 @@ class QuantizedLinear(Module): self.weight.T, scales=self.scales, biases=self.biases, - groups=self.groups, - width=self.width, + group_size=self.group_size, + bits=self.bits, ) if "bias" in self: x = x + self.bias return x @classmethod - def from_linear(cls, linear_layer: Module, groups: int = 64, width: int = 4): + def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4): """Create a QuantizedLinear layer from the parameters of a provided linear layer.""" output_dims, input_dims = linear_layer.weight.shape - ql = cls(input_dims, output_dims, False, groups, width) + ql = cls(input_dims, output_dims, False, group_size, bits) ql.weight, ql.scales, ql.biases = mx.quantize( - linear_layer.weight, groups, width + linear_layer.weight, group_size, bits ) if "bias" in linear_layer: ql.bias = linear_layer.bias @@ -109,13 +109,13 @@ class QuantizedLinear(Module): def quantize_module( cls, model: Module, - groups: int = 64, - width: int = 4, + group_size: int = 64, + bits: int = 4, linear_class_predicate=lambda m: isinstance(m, Linear), ): def _quantize_if_linear(m): if linear_class_predicate(m): - return cls.from_linear(m, groups, width) + return cls.from_linear(m, group_size, bits) else: return m diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 627dd9a80..f8e4f237a 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3011,26 +3011,27 @@ void init_ops(py::module_& m) { py::pos_only(), "scales"_a, "biases"_a, - "groups"_a = 128, - "width"_a = 4, + "group_size"_a = 64, + "bits"_a = 4, py::kw_only(), "stream"_a = none, R"pbdoc( - quantized_matmul(x: array, w: array, scales: array, biases: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array + quantized_matmul(x: array, w: array, scales: array, biases: array, /, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array Perform the matrix multiplication with the quantized matrix ``w``. The - quantization uses one floating point scale and bias per ``groups`` of - elements. Each element in ``w`` takes ``width`` bits and is packed in an + quantization uses one floating point scale and bias per ``group_size`` of + elements. Each element in ``w`` takes ``bits`` bits and is packed in an unsigned 32 bit integer. 
Args: x (array): Input array w (array): Quantized matrix packed in unsigned integers - scales (array): The scales to use per ``groups`` elements of ``w`` - biases (array): The biases to use per ``groups`` elements of ``w`` - groups (int): The size of the group in ``w`` that shares a scale and - bias. (default: 128) - width (int): The bitwidth of the elements in ``w``. (default: 4) + scales (array): The scales to use per ``group_size`` elements of ``w`` + biases (array): The biases to use per ``group_size`` elements of ``w`` + group_size (int, optional): The size of the group in ``w`` that + shares a scale and bias. (default: 64) + bits (int, optional): The number of bits occupied by each element in + ``w``. (default: 4) Returns: result (array): The result of the multiplication of ``x`` with ``w``. @@ -3040,19 +3041,19 @@ void init_ops(py::module_& m) { &quantize, "w"_a, py::pos_only(), - "groups"_a = 128, - "width"_a = 4, + "group_size"_a = 64, + "bits"_a = 4, py::kw_only(), "stream"_a = none, R"pbdoc( - quantize(w: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array] + quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array] - Quantize the matrix ``w`` using ``width`` bits per element. + Quantize the matrix ``w`` using ``bits`` bits per element. - Note, every ``groups`` elements in a row of ``w`` are quantized + Note, every ``group_size`` elements in a row of ``w`` are quantized together. Hence, number of columns of ``w`` should be divisible by - ``groups``. In particular, the rows of ``w`` are divided into groups of - size ``groups`` which are quantized together. + ``group_size``. In particular, the rows of ``w`` are divided into groups of + size ``group_size`` which are quantized together. .. warning:: @@ -3083,10 +3084,10 @@ void init_ops(py::module_& m) { Args: w (array): Matrix to be quantized - groups (int, optional): The size of the group in ``w`` that shares a - scale and bias. (default: 128) - width (int, optional): The bitwidth of the elements in ``w``. - (default: 4) + group_size (int, optional): The size of the group in ``w`` that shares a + scale and bias. (default: 64) + bits (int, optional): The number of bits occupied by each element of + ``w`` in the returned quantized matrix. (default: 4) Returns: (tuple): A tuple containing @@ -3102,15 +3103,15 @@ void init_ops(py::module_& m) { py::pos_only(), "scales"_a, "biases"_a, - "groups"_a = 128, - "width"_a = 4, + "group_size"_a = 64, + "bits"_a = 4, py::kw_only(), "stream"_a = none, R"pbdoc( - dequantize(w: array, /, scales: array, biases: array, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array + dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array Dequantize the matrix ``w`` using the provided ``scales`` and - ``biases`` and the ``groups`` and ``width`` configuration. + ``biases`` and the ``group_size`` and ``bits`` configuration. 
Formally, given the notation in :func:`quantize`, we compute :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and @@ -3122,14 +3123,14 @@ void init_ops(py::module_& m) { Args: w (array): Matrix to be quantized - scales (array): The scales to use per ``groups`` elements of ``w`` - biases (array): The biases to use per ``groups`` elements of ``w`` - groups (int, optional): The size of the group in ``w`` that shares a - scale and bias. (default: 128) - width (int, optional): The bitwidth of the elements in ``w``. - (default: 4) + scales (array): The scales to use per ``group_size`` elements of ``w`` + biases (array): The biases to use per ``group_size`` elements of ``w`` + group_size (int, optional): The size of the group in ``w`` that shares a + scale and bias. (default: 64) + bits (int, optional): The number of bits occupied by each element in + ``w``. (default: 4) Returns: - result (array): The dequantized version of w + result (array): The dequantized version of ``w`` )pbdoc"); } diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 72af0558c..5fcc882a5 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -18,22 +18,22 @@ class TestQuantized(mlx_tests.MLXTestCase): def test_qmm(self): key = mx.random.key(0) k1, k2 = mx.random.split(key) - for groups in [128, 64]: - for width in [2, 4, 8]: + for group_size in [128, 64]: + for bits in [2, 4, 8]: for M in [8, 32, 33, 64]: for N in [512, 1024]: for K in [512, 1024]: with self.subTest( - shape=(M, N, K), groups=groups, width=width + shape=(M, N, K), group_size=group_size, bits=bits ): x = mx.random.normal(shape=(M, K), key=k1) w = mx.random.normal(shape=(N, K), key=k2) - w_q, scales, biases = mx.quantize(w, groups, width) + w_q, scales, biases = mx.quantize(w, group_size, bits) w_hat = mx.dequantize( - w_q, scales, biases, groups, width + w_q, scales, biases, group_size, bits ) y_q = mx.quantized_matmul( - x, w_q.T, scales, biases, width=width, groups=groups + x, w_q.T, scales, biases, group_size, bits ) y_hat = x @ w_hat.T self.assertEqual(y_q.shape, y_hat.shape) @@ -42,16 +42,14 @@ class TestQuantized(mlx_tests.MLXTestCase): def test_qmm_shapes(self): key = mx.random.key(0) k1, k2 = mx.random.split(key) - groups = 64 - width = 4 + group_size = 64 + bits = 4 w = mx.random.normal(shape=(32, 128), key=k2) - w_q, scales, biases = mx.quantize(w, groups, width) - w_hat = mx.dequantize(w_q, scales, biases, groups, width) + w_q, scales, biases = mx.quantize(w, group_size, bits) + w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) for s in [(3, 128), (2, 1, 7, 128)]: x = mx.random.normal(shape=(3, 128), key=k1) - y_q = mx.quantized_matmul( - x, w_q.T, scales, biases, width=width, groups=groups - ) + y_q = mx.quantized_matmul(x, w_q.T, scales, biases, group_size, bits) y_hat = x @ w_hat.T self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) @@ -59,17 +57,19 @@ class TestQuantized(mlx_tests.MLXTestCase): def test_qmv(self): key = mx.random.key(0) k1, k2 = mx.random.split(key) - for groups in [128, 64]: - for width in [2, 4, 8]: + for group_size in [128, 64]: + for bits in [2, 4, 8]: for M in [512, 1024]: for N in [512, 1024]: - with self.subTest(shape=(M, N), groups=groups, width=width): + with self.subTest( + shape=(M, N), group_size=group_size, bits=bits + ): x = mx.random.normal(shape=(1, N), key=k1) w = mx.random.normal(shape=(M, N), key=k2) - w_q, scales, biases = mx.quantize(w, groups, width) - w_hat = mx.dequantize(w_q, 
scales, biases, groups, width) + w_q, scales, biases = mx.quantize(w, group_size, bits) + w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) y_q = mx.quantized_matmul( - x, w_q.T, scales, biases, width=width, groups=groups + x, w_q.T, scales, biases, group_size, bits ) y_hat = x @ w_hat.T self.assertEqual(y_q.shape, y_hat.shape)
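
A minimal usage sketch of the renamed arguments (a sketch, not part of the patch itself): it mirrors the calls exercised in python/tests/test_quantized.py and assumes an MLX build that includes this change. The QuantizedLinear import path is the one shown in python/mlx/nn/layers/quantized.py above; the tensor shapes are illustrative only.

import mlx.core as mx
from mlx.nn.layers.quantized import QuantizedLinear

x = mx.random.normal(shape=(8, 512))     # activations, shape (M, K)
w = mx.random.normal(shape=(1024, 512))  # weights, shape (N, K)

# Renamed arguments: group_size (new default 64, was groups=128) and bits (was width).
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
y_q = mx.quantized_matmul(x, w_q.T, scales, biases, group_size=64, bits=4)

# Round-trip check against the dequantized weights, as in the tests.
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)
print((y_q - x @ w_hat.T).abs().max())

# The layer-level API uses the same argument names.
layer = QuantizedLinear(512, 1024, bias=True, group_size=64, bits=4)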