let class predicate specify quantization parameters (#1638)

This commit is contained in:
Alex Barron 2024-12-02 14:09:28 -08:00 committed by GitHub
parent e4eeb4e910
commit 1445dcaa60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import math
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -12,7 +12,7 @@ def quantize(
     model: Module,
     group_size: int = 64,
     bits: int = 4,
-    class_predicate: Optional[Callable] = None,
+    class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
 ):
     """Quantize the sub-modules of a module according to a predicate.
 
@@ -27,17 +27,26 @@ def quantize(
         bits (int): The number of bits per parameter (see
             :func:`mlx.core.quantize`). Default: ``4``.
         class_predicate (Optional[Callable]): A callable which receives the
-            :obj:`Module` path and :obj:`Module` itself and returns ``True`` if
-            it should be quantized and ``False`` otherwise. If ``None``, then
-            all layers that define a ``to_quantized(group_size, bits)`` method
-            are quantized. Default: ``None``.
+            :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a
+            dict of params for `to_quantized` if it should be quantized and
+            ``False`` otherwise. If ``None``, then all layers that define a
+            ``to_quantized(group_size, bits)`` method are quantized.
+            Default: ``None``.
     """
     class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
 
     def _maybe_quantize(path, m):
-        if class_predicate(path, m):
+        if bool_or_params := class_predicate(path, m):
             if hasattr(m, "to_quantized"):
-                return m.to_quantized(group_size, bits)
+                if isinstance(bool_or_params, bool):
+                    return m.to_quantized(group_size=group_size, bits=bits)
+                elif isinstance(bool_or_params, dict):
+                    return m.to_quantized(**bool_or_params)
+                else:
+                    raise ValueError(
+                        "``class_predicate`` must return a bool"
+                        " or a dict of parameters to pass to ``to_quantized``"
+                    )
             else:
                 raise ValueError(f"Unable to quantize model of type {type(m)}")
         else:
@@ -197,7 +206,7 @@ class QuantizedLinear(Module):
         out_dims, in_dims = self.weight.shape
         in_dims *= 32 // self.bits
         return (
-            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
+            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
             f"group_size={self.group_size}, bits={self.bits}"
         )
 