saving/loading mixed quantizations

Alex Barron 2024-12-02 10:19:39 -08:00
parent 042280ce50
commit 80e5c37bb9


@@ -460,13 +460,17 @@ def load_model(
     if (quantization := config.get("quantization", None)) is not None:
         # Handle legacy models which may not have everything quantized
         def class_predicate(p, m):
+            # Handle custom per layer quantizations
+            if p in config["quantization"]:
+                return config["quantization"][p]
             if not hasattr(m, "to_quantized"):
                 return False
             return f"{p}.scales" in weights

         nn.quantize(
             model,
-            **quantization,
+            group_size=quantization["group_size"],
+            bits=quantization["bits"],
             class_predicate=class_predicate,
         )
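For context, the per-layer entries that `class_predicate` looks up here are the ones `quantize_model` records further down whenever a predicate returns a parameter dict. A saved "quantization" section produced this way might look roughly like the following (the layer path and values are illustrative only):

    {
        "group_size": 64,
        "bits": 4,
        "model.layers.0.self_attn.q_proj": {"group_size": 64, "bits": 8}
    }

The top-level "group_size"/"bits" are the defaults passed to nn.quantize; a layer whose path appears as a key gets its own parameters, and everything else falls back to the `f"{p}.scales" in weights` check above.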
@@ -669,7 +673,13 @@ def save_weights(
 def quantize_model(
-    model: nn.Module, config: dict, q_group_size: int, q_bits: int
+    model: nn.Module,
+    config: dict,
+    q_group_size: int,
+    q_bits: int,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ) -> Tuple:
     """
     Applies quantization to the model weights.
@@ -679,13 +689,31 @@ def quantize_model(
         config (dict): Model configuration.
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.
+        quant_predicate (Callable): A callable that decides how
+            to quantize each layer based on the path.
+            Accepts the layer `path`, the `module` and the model `config`.
+            Returns either a bool to signify quantize/no quantize or
+            a dict of quantization parameters to pass to `to_quantized`.

     Returns:
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+
+    # Add any custom quantization parameters to the config as we go
+    def _class_predicate(p, m):
+        bool_or_params = quant_predicate(p, m, config)
+        if isinstance(bool_or_params, dict):
+            quantized_config["quantization"][p] = bool_or_params
+        return bool_or_params
+
+    nn.quantize(
+        model,
+        q_group_size,
+        q_bits,
+        class_predicate=_class_predicate if quant_predicate else None,
+    )
     # support hf model tree #957
     quantized_config["quantization_config"] = quantized_config["quantization"]
     quantized_weights = dict(tree_flatten(model.parameters()))
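As a usage sketch of the new hook, a `quant_predicate` might look like the function below (the name, path checks, and parameter choices are illustrative, not part of this commit): it skips modules that cannot be quantized, requests different parameters for the embeddings and output head, and leaves everything else at the default `q_group_size`/`q_bits`.

    def mixed_quant_predicate(path: str, module, config: dict):
        # Modules without `to_quantized` cannot be quantized at all
        if not hasattr(module, "to_quantized"):
            return False
        # Per-layer override: this dict is passed to `to_quantized` and is
        # recorded under the layer path in the saved quantization config
        if "lm_head" in path or "embed_tokens" in path:
            return {"group_size": 64, "bits": 8}
        # Everything else uses the default q_group_size / q_bits
        return True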
@@ -726,6 +754,9 @@ def convert(
     upload_repo: str = None,
     revision: Optional[str] = None,
     dequantize: bool = False,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ):
     # Check the save path is empty
     if isinstance(mlx_path, str):
@@ -751,7 +782,9 @@
     if quantize:
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
-        weights, config = quantize_model(model, config, q_group_size, q_bits)
+        weights, config = quantize_model(
+            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
+        )

     if dequantize:
         print("[INFO] Dequantizing")