diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py
index 48c2ce13a..aa5996dab 100644
--- a/python/mlx/nn/layers/quantized.py
+++ b/python/mlx/nn/layers/quantized.py
@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import math
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -12,7 +12,7 @@ def quantize(
     model: Module,
     group_size: int = 64,
     bits: int = 4,
-    class_predicate: Optional[Callable] = None,
+    class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
 ):
     """Quantize the sub-modules of a module according to a predicate.
 
@@ -27,17 +27,26 @@ def quantize(
         bits (int): The number of bits per parameter (see
            :func:`mlx.core.quantize`). Default: ``4``.
         class_predicate (Optional[Callable]): A callable which receives the
-          :obj:`Module` path and :obj:`Module` itself and returns ``True`` if
-          it should be quantized and ``False`` otherwise. If ``None``, then
-          all layers that define a ``to_quantized(group_size, bits)`` method
-          are quantized. Default: ``None``.
+          :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a
+          dict of params for `to_quantized` if it should be quantized and
+          ``False`` otherwise. If ``None``, then all layers that define a
+          ``to_quantized(group_size, bits)`` method are quantized.
+          Default: ``None``.
     """
     class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
 
     def _maybe_quantize(path, m):
-        if class_predicate(path, m):
+        if bool_or_params := class_predicate(path, m):
             if hasattr(m, "to_quantized"):
-                return m.to_quantized(group_size, bits)
+                if isinstance(bool_or_params, bool):
+                    return m.to_quantized(group_size=group_size, bits=bits)
+                elif isinstance(bool_or_params, dict):
+                    return m.to_quantized(**bool_or_params)
+                else:
+                    raise ValueError(
+                        "``class_predicate`` must return a bool"
+                        " or a dict of parameters to pass to ``to_quantized``"
+                    )
             else:
                 raise ValueError(f"Unable to quantize model of type {type(m)}")
         else:
@@ -197,7 +206,7 @@ class QuantizedLinear(Module):
         out_dims, in_dims = self.weight.shape
         in_dims *= 32 // self.bits
         return (
-            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
+            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
             f"group_size={self.group_size}, bits={self.bits}"
         )
 
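
For context, a minimal sketch of how the extended predicate might be used. The `TinyModel` class and the `lm_head` path below are hypothetical, invented for illustration; what the patch guarantees is that a ``True`` return falls back to the top-level ``group_size`` and ``bits``, while a dict return is forwarded to ``to_quantized`` as keyword arguments.

    import mlx.nn as nn

    class TinyModel(nn.Module):  # hypothetical model for illustration
        def __init__(self):
            super().__init__()
            self.body = nn.Linear(64, 64)
            self.lm_head = nn.Linear(64, 128)

    def predicate(path, module):
        # Only modules that implement ``to_quantized`` can be quantized.
        if not hasattr(module, "to_quantized"):
            return False
        # Hypothetical policy: quantize the output head less aggressively by
        # returning a dict, which ``quantize`` forwards to ``to_quantized``.
        if path == "lm_head":
            return {"group_size": 32, "bits": 8}
        # ``True`` uses the top-level ``group_size`` and ``bits`` below.
        return True

    model = TinyModel()
    nn.quantize(model, group_size=64, bits=4, class_predicate=predicate)

Returning ``True`` preserves the previous behavior, so existing predicates keep working; returning a dict lets callers set per-layer quantization parameters in a single pass over the model.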