mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	let class predicate specify quantization parameters (#1638)
This commit is contained in:
		| @@ -1,7 +1,7 @@ | |||||||
| # Copyright © 2023-2024 Apple Inc. | # Copyright © 2023-2024 Apple Inc. | ||||||
|  |  | ||||||
| import math | import math | ||||||
| from typing import Callable, Optional | from typing import Callable, Optional, Union | ||||||
|  |  | ||||||
| import mlx.core as mx | import mlx.core as mx | ||||||
| from mlx.nn.layers.base import Module | from mlx.nn.layers.base import Module | ||||||
| @@ -12,7 +12,7 @@ def quantize( | |||||||
|     model: Module, |     model: Module, | ||||||
|     group_size: int = 64, |     group_size: int = 64, | ||||||
|     bits: int = 4, |     bits: int = 4, | ||||||
|     class_predicate: Optional[Callable] = None, |     class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None, | ||||||
| ): | ): | ||||||
|     """Quantize the sub-modules of a module according to a predicate. |     """Quantize the sub-modules of a module according to a predicate. | ||||||
|  |  | ||||||
| @@ -27,17 +27,26 @@ def quantize( | |||||||
|         bits (int): The number of bits per parameter (see |         bits (int): The number of bits per parameter (see | ||||||
|            :func:`mlx.core.quantize`). Default: ``4``. |            :func:`mlx.core.quantize`). Default: ``4``. | ||||||
|         class_predicate (Optional[Callable]): A callable which receives the |         class_predicate (Optional[Callable]): A callable which receives the | ||||||
|           :obj:`Module` path and :obj:`Module` itself and returns ``True`` if |           :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a | ||||||
|           it should be quantized and ``False`` otherwise. If ``None``, then |           dict of params for `to_quantized` if it should be quantized and | ||||||
|           all layers that define a ``to_quantized(group_size, bits)`` method |           ``False`` otherwise. If ``None``, then all layers that define a | ||||||
|           are quantized. Default: ``None``. |           ``to_quantized(group_size, bits)`` method are quantized. | ||||||
|  |           Default: ``None``. | ||||||
|     """ |     """ | ||||||
|     class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized")) |     class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized")) | ||||||
|  |  | ||||||
|     def _maybe_quantize(path, m): |     def _maybe_quantize(path, m): | ||||||
|         if class_predicate(path, m): |         if bool_or_params := class_predicate(path, m): | ||||||
|             if hasattr(m, "to_quantized"): |             if hasattr(m, "to_quantized"): | ||||||
|                 return m.to_quantized(group_size, bits) |                 if isinstance(bool_or_params, bool): | ||||||
|  |                     return m.to_quantized(group_size=group_size, bits=bits) | ||||||
|  |                 elif isinstance(bool_or_params, dict): | ||||||
|  |                     return m.to_quantized(**bool_or_params) | ||||||
|  |                 else: | ||||||
|  |                     raise ValueError( | ||||||
|  |                         "``class_predicate`` must return a bool" | ||||||
|  |                         " or a dict of parameters to pass to ``to_quantized``" | ||||||
|  |                     ) | ||||||
|             else: |             else: | ||||||
|                 raise ValueError(f"Unable to quantize model of type {type(m)}") |                 raise ValueError(f"Unable to quantize model of type {type(m)}") | ||||||
|         else: |         else: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron