let class predicate specify quantization parameters (#1638)

This commit is contained in:
Alex Barron 2024-12-02 14:09:28 -08:00 committed by GitHub
parent e4eeb4e910
commit 1445dcaa60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
import math
from typing import Callable, Optional
from typing import Callable, Optional, Union
import mlx.core as mx
from mlx.nn.layers.base import Module
@ -12,7 +12,7 @@ def quantize(
model: Module,
group_size: int = 64,
bits: int = 4,
class_predicate: Optional[Callable] = None,
class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
):
"""Quantize the sub-modules of a module according to a predicate.
@ -27,17 +27,26 @@ def quantize(
bits (int): The number of bits per parameter (see
:func:`mlx.core.quantize`). Default: ``4``.
class_predicate (Optional[Callable]): A callable which receives the
:obj:`Module` path and :obj:`Module` itself and returns ``True`` if
it should be quantized and ``False`` otherwise. If ``None``, then
all layers that define a ``to_quantized(group_size, bits)`` method
are quantized. Default: ``None``.
:obj:`Module` path and :obj:`Module` itself and returns ``True`` or a
dict of parameters for ``to_quantized`` if it should be quantized and
``False`` otherwise. If ``None``, then all layers that define a
``to_quantized(group_size, bits)`` method are quantized.
Default: ``None``.
"""
class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
def _maybe_quantize(path, m):
if class_predicate(path, m):
if bool_or_params := class_predicate(path, m):
if hasattr(m, "to_quantized"):
return m.to_quantized(group_size, bits)
if isinstance(bool_or_params, bool):
return m.to_quantized(group_size=group_size, bits=bits)
elif isinstance(bool_or_params, dict):
return m.to_quantized(**bool_or_params)
else:
raise ValueError(
"``class_predicate`` must return a bool"
" or a dict of parameters to pass to ``to_quantized``"
)
else:
raise ValueError(f"Unable to quantize model of type {type(m)}")
else: