saving/loading mixed quantizations

Alex Barron 2024-12-02 10:19:39 -08:00
parent 042280ce50
commit 80e5c37bb9


@@ -460,13 +460,17 @@ def load_model(
     if (quantization := config.get("quantization", None)) is not None:
         # Handle legacy models which may not have everything quantized
         def class_predicate(p, m):
+            # Handle custom per layer quantizations
+            if p in config["quantization"]:
+                return config["quantization"][p]
             if not hasattr(m, "to_quantized"):
                 return False
             return f"{p}.scales" in weights

         nn.quantize(
             model,
-            **quantization,
+            group_size=quantization["group_size"],
+            bits=quantization["bits"],
             class_predicate=class_predicate,
         )
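
With this change, the "quantization" section of a saved config can carry per-layer entries keyed by module path alongside the global group_size/bits, and class_predicate returns those per-layer parameter dicts so nn.quantize applies them at load time. A hypothetical example of such a section, written as a Python dict (the layer paths are illustrative, not from this commit):

    # Hypothetical config["quantization"]: global 4-bit defaults plus
    # per-layer overrides keyed by module path (paths are illustrative).
    quantization = {
        "group_size": 64,
        "bits": 4,
        "model.layers.0.self_attn.q_proj": {"group_size": 64, "bits": 8},
        "lm_head": {"group_size": 64, "bits": 8},
    }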
@@ -669,7 +673,13 @@ def save_weights(


 def quantize_model(
-    model: nn.Module, config: dict, q_group_size: int, q_bits: int
+    model: nn.Module,
+    config: dict,
+    q_group_size: int,
+    q_bits: int,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ) -> Tuple:
     """
     Applies quantization to the model weights.
@@ -679,13 +689,31 @@ def quantize_model(
         config (dict): Model configuration.
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.
+        quant_predicate (Callable): A callable that decides how
+            to quantize each layer based on the path.
+            Accepts the layer `path`, the `module` and the model `config`.
+            Returns either a bool to signify quantize/no quantize or
+            a dict of quantization parameters to pass to `to_quantized`.

     Returns:
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+
+    # Add any custom quantization parameters to the config as we go
+    def _class_predicate(p, m):
+        bool_or_params = quant_predicate(p, m, config)
+        if isinstance(bool_or_params, dict):
+            quantized_config["quantization"][p] = bool_or_params
+        return bool_or_params
+
+    nn.quantize(
+        model,
+        q_group_size,
+        q_bits,
+        class_predicate=_class_predicate if quant_predicate else None,
+    )
+
     # support hf model tree #957
     quantized_config["quantization_config"] = quantized_config["quantization"]
     quantized_weights = dict(tree_flatten(model.parameters()))
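
For reference, a minimal sketch of a custom quant_predicate matching the docstring contract above (the function name and layer-path checks are illustrative, not part of this commit): it returns a parameter dict for layers needing higher precision, False for layers to skip, and True to fall back to the global q_group_size/q_bits.

    def mixed_quant_predicate(path, module, config):
        # A dict returned here is forwarded to the layer's `to_quantized`
        # and recorded in quantized_config["quantization"][path].
        if path == "lm_head":
            return {"group_size": 64, "bits": 8}
        # Booleans either skip the layer (False) or keep the
        # global group_size/bits settings (True).
        if "gate" in path:
            return False
        return True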
@@ -726,6 +754,9 @@ def convert(
     upload_repo: str = None,
     revision: Optional[str] = None,
     dequantize: bool = False,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ):
     # Check the save path is empty
     if isinstance(mlx_path, str):
@@ -751,7 +782,9 @@ def convert(
     if quantize:
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
-        weights, config = quantize_model(model, config, q_group_size, q_bits)
+        weights, config = quantize_model(
+            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
+        )

     if dequantize:
         print("[INFO] Dequantizing")
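
A sketch of end-to-end usage, assuming convert is exported from this module (the import path, hf_path parameter name, and repo/output names are assumptions for illustration):

    from mlx_lm import convert  # assumed public entry point

    convert(
        hf_path="mistralai/Mistral-7B-v0.1",  # illustrative source repo
        mlx_path="mlx_model_mixed",           # illustrative output dir
        quantize=True,
        q_group_size=64,
        q_bits=4,
        quant_predicate=mixed_quant_predicate,  # as sketched above
    )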