Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-19 07:31:26 +08:00)
let class predicate specify quantization parameters (#1638)

commit 1445dcaa60
parent e4eeb4e910
@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import math
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -12,7 +12,7 @@ def quantize(
     model: Module,
     group_size: int = 64,
     bits: int = 4,
-    class_predicate: Optional[Callable] = None,
+    class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
 ):
     """Quantize the sub-modules of a module according to a predicate.
 
@@ -27,17 +27,26 @@ def quantize(
         bits (int): The number of bits per parameter (see
             :func:`mlx.core.quantize`). Default: ``4``.
         class_predicate (Optional[Callable]): A callable which receives the
-            :obj:`Module` path and :obj:`Module` itself and returns ``True`` if
-            it should be quantized and ``False`` otherwise. If ``None``, then
-            all layers that define a ``to_quantized(group_size, bits)`` method
-            are quantized. Default: ``None``.
+            :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a
+            dict of params for `to_quantized` if it should be quantized and
+            ``False`` otherwise. If ``None``, then all layers that define a
+            ``to_quantized(group_size, bits)`` method are quantized.
+            Default: ``None``.
     """
     class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
 
     def _maybe_quantize(path, m):
-        if class_predicate(path, m):
+        if bool_or_params := class_predicate(path, m):
             if hasattr(m, "to_quantized"):
-                return m.to_quantized(group_size, bits)
+                if isinstance(bool_or_params, bool):
+                    return m.to_quantized(group_size=group_size, bits=bits)
+                elif isinstance(bool_or_params, dict):
+                    return m.to_quantized(**bool_or_params)
+                else:
+                    raise ValueError(
+                        "``class_predicate`` must return a bool"
+                        " or a dict of parameters to pass to ``to_quantized``"
+                    )
             else:
                 raise ValueError(f"Unable to quantize model of type {type(m)}")
         else:
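
The hunk above lets the predicate pick per-layer quantization settings: returning ``True`` keeps the global ``group_size``/``bits``, while returning a dict forwards its keys to ``to_quantized``. A minimal usage sketch follows; the model, the ``layers.0`` path, and the parameter choices are illustrative assumptions, not part of the commit:

import mlx.nn as nn

model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 256))

def predicate(path, module):
    # Modules without a to_quantized method are never quantized.
    if not hasattr(module, "to_quantized"):
        return False
    # Hypothetical rule: quantize the first layer with different settings.
    if path.startswith("layers.0"):
        return {"group_size": 32, "bits": 8}
    # A plain True falls back to the group_size/bits passed to quantize().
    return True

nn.quantize(model, group_size=64, bits=4, class_predicate=predicate)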
@@ -197,7 +206,7 @@ class QuantizedLinear(Module):
         out_dims, in_dims = self.weight.shape
         in_dims *= 32 // self.bits
         return (
-            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
+            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
             f"group_size={self.group_size}, bits={self.bits}"
         )
 
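The last hunk only restores the missing space after ``bias=...`` in ``QuantizedLinear``'s repr string. A quick sketch of the effect, with made-up dimensions:

import mlx.nn as nn

layer = nn.QuantizedLinear(512, 512, bias=False, group_size=64, bits=4)
print(layer)
# Before the fix: QuantizedLinear(input_dims=512, output_dims=512, bias=False,group_size=64, bits=4)
# After the fix:  QuantizedLinear(input_dims=512, output_dims=512, bias=False, group_size=64, bits=4)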