mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	let class predicate specify quantization parameters (#1638)
This commit is contained in:
		| @@ -1,7 +1,7 @@ | ||||
| # Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| import math | ||||
| from typing import Callable, Optional | ||||
| from typing import Callable, Optional, Union | ||||
|  | ||||
| import mlx.core as mx | ||||
| from mlx.nn.layers.base import Module | ||||
| @@ -12,7 +12,7 @@ def quantize( | ||||
|     model: Module, | ||||
|     group_size: int = 64, | ||||
|     bits: int = 4, | ||||
|     class_predicate: Optional[Callable] = None, | ||||
|     class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None, | ||||
| ): | ||||
|     """Quantize the sub-modules of a module according to a predicate. | ||||
|  | ||||
| @@ -27,17 +27,26 @@ def quantize( | ||||
|         bits (int): The number of bits per parameter (see | ||||
|            :func:`mlx.core.quantize`). Default: ``4``. | ||||
|         class_predicate (Optional[Callable]): A callable which receives the | ||||
|           :obj:`Module` path and :obj:`Module` itself and returns ``True`` if | ||||
|           it should be quantized and ``False`` otherwise. If ``None``, then | ||||
|           all layers that define a ``to_quantized(group_size, bits)`` method | ||||
|           are quantized. Default: ``None``. | ||||
|           :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a | ||||
|           dict of params for `to_quantized` if it should be quantized and | ||||
|           ``False`` otherwise. If ``None``, then all layers that define a | ||||
|           ``to_quantized(group_size, bits)`` method are quantized. | ||||
|           Default: ``None``. | ||||
|     """ | ||||
|     class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized")) | ||||
|  | ||||
|     def _maybe_quantize(path, m): | ||||
|         if class_predicate(path, m): | ||||
|         if bool_or_params := class_predicate(path, m): | ||||
|             if hasattr(m, "to_quantized"): | ||||
|                 return m.to_quantized(group_size, bits) | ||||
|                 if isinstance(bool_or_params, bool): | ||||
|                     return m.to_quantized(group_size=group_size, bits=bits) | ||||
|                 elif isinstance(bool_or_params, dict): | ||||
|                     return m.to_quantized(**bool_or_params) | ||||
|                 else: | ||||
|                     raise ValueError( | ||||
|                         "``class_predicate`` must return a bool" | ||||
|                         " or a dict of parameters to pass to ``to_quantized``" | ||||
|                     ) | ||||
|             else: | ||||
|                 raise ValueError(f"Unable to quantize model of type {type(m)}") | ||||
|         else: | ||||
| @@ -197,7 +206,7 @@ class QuantizedLinear(Module): | ||||
|         out_dims, in_dims = self.weight.shape | ||||
|         in_dims *= 32 // self.bits | ||||
|         return ( | ||||
|             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}," | ||||
|             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " | ||||
|             f"group_size={self.group_size}, bits={self.bits}" | ||||
|         ) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron