mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 07:31:26 +08:00
let class predicate specify quantization parameters (#1638)
This commit is contained in:
parent
e4eeb4e910
commit
1445dcaa60
@ -1,7 +1,7 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.nn.layers.base import Module
|
from mlx.nn.layers.base import Module
|
||||||
@ -12,7 +12,7 @@ def quantize(
|
|||||||
model: Module,
|
model: Module,
|
||||||
group_size: int = 64,
|
group_size: int = 64,
|
||||||
bits: int = 4,
|
bits: int = 4,
|
||||||
class_predicate: Optional[Callable] = None,
|
class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
|
||||||
):
|
):
|
||||||
"""Quantize the sub-modules of a module according to a predicate.
|
"""Quantize the sub-modules of a module according to a predicate.
|
||||||
|
|
||||||
@ -27,17 +27,26 @@ def quantize(
|
|||||||
bits (int): The number of bits per parameter (see
|
bits (int): The number of bits per parameter (see
|
||||||
:func:`mlx.core.quantize`). Default: ``4``.
|
:func:`mlx.core.quantize`). Default: ``4``.
|
||||||
class_predicate (Optional[Callable]): A callable which receives the
|
class_predicate (Optional[Callable]): A callable which receives the
|
||||||
:obj:`Module` path and :obj:`Module` itself and returns ``True`` if
|
:obj:`Module` path and :obj:`Module` itself and returns ``True`` or a
|
||||||
it should be quantized and ``False`` otherwise. If ``None``, then
|
dict of params for `to_quantized` if it should be quantized and
|
||||||
all layers that define a ``to_quantized(group_size, bits)`` method
|
``False`` otherwise. If ``None``, then all layers that define a
|
||||||
are quantized. Default: ``None``.
|
``to_quantized(group_size, bits)`` method are quantized.
|
||||||
|
Default: ``None``.
|
||||||
"""
|
"""
|
||||||
class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
|
class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
|
||||||
|
|
||||||
def _maybe_quantize(path, m):
|
def _maybe_quantize(path, m):
|
||||||
if class_predicate(path, m):
|
if bool_or_params := class_predicate(path, m):
|
||||||
if hasattr(m, "to_quantized"):
|
if hasattr(m, "to_quantized"):
|
||||||
return m.to_quantized(group_size, bits)
|
if isinstance(bool_or_params, bool):
|
||||||
|
return m.to_quantized(group_size=group_size, bits=bits)
|
||||||
|
elif isinstance(bool_or_params, dict):
|
||||||
|
return m.to_quantized(**bool_or_params)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"``class_predicate`` must return a bool"
|
||||||
|
" or a dict of parameters to pass to ``to_quantized``"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unable to quantize model of type {type(m)}")
|
raise ValueError(f"Unable to quantize model of type {type(m)}")
|
||||||
else:
|
else:
|
||||||
@ -197,7 +206,7 @@ class QuantizedLinear(Module):
|
|||||||
out_dims, in_dims = self.weight.shape
|
out_dims, in_dims = self.weight.shape
|
||||||
in_dims *= 32 // self.bits
|
in_dims *= 32 // self.bits
|
||||||
return (
|
return (
|
||||||
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
|
f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
|
||||||
f"group_size={self.group_size}, bits={self.bits}"
|
f"group_size={self.group_size}, bits={self.bits}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user