mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-28 23:49:43 +08:00)

saving/loading mixed quantizations

parent: 042280ce50
commit: 80e5c37bb9
@@ -460,13 +460,17 @@ def load_model(
     if (quantization := config.get("quantization", None)) is not None:
         # Handle legacy models which may not have everything quantized
         def class_predicate(p, m):
+            # Handle custom per layer quantizations
+            if p in config["quantization"]:
+                return config["quantization"][p]
             if not hasattr(m, "to_quantized"):
                 return False
             return f"{p}.scales" in weights
 
         nn.quantize(
             model,
-            **quantization,
+            group_size=quantization["group_size"],
+            bits=quantization["bits"],
             class_predicate=class_predicate,
         )
 
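The per-layer entries that `class_predicate` now looks up are the parameter dicts recorded by `quantize_model` (below) while quantizing. A minimal sketch of what the "quantization" section of a saved mixed-quantization config might contain, with layer paths and values that are illustrative rather than taken from the commit:

# Hypothetical "quantization" section of a saved config:
# global defaults plus per-layer overrides keyed by module path.
quantization = {
    "group_size": 64,  # global default group size
    "bits": 4,         # global default bit width
    # per-layer override recorded by quantize_model (path/values illustrative)
    "model.layers.0.self_attn.q_proj": {"group_size": 64, "bits": 8},
}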
@@ -669,7 +673,13 @@ def save_weights(
 
 
 def quantize_model(
-    model: nn.Module, config: dict, q_group_size: int, q_bits: int
+    model: nn.Module,
+    config: dict,
+    q_group_size: int,
+    q_bits: int,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ) -> Tuple:
     """
     Applies quantization to the model weights.
@@ -679,13 +689,31 @@ def quantize_model(
         config (dict): Model configuration.
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.
+        quant_predicate (Callable): A callable that decides how
+            to quantize each layer based on the path.
+            Accepts the layer `path`, the `module` and the model `config`.
+            Returns either a bool to signify quantize/no quantize or
+            a dict of quantization parameters to pass to `to_quantized`.
 
     Returns:
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+
+    # Add any custom quantization parameters to the config as we go
+    def _class_predicate(p, m):
+        bool_or_params = quant_predicate(p, m, config)
+        if isinstance(bool_or_params, dict):
+            quantized_config["quantization"][p] = bool_or_params
+        return bool_or_params
+
+    nn.quantize(
+        model,
+        q_group_size,
+        q_bits,
+        class_predicate=_class_predicate if quant_predicate else None,
+    )
     # support hf model tree #957
     quantized_config["quantization_config"] = quantized_config["quantization"]
     quantized_weights = dict(tree_flatten(model.parameters()))
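The docstring above defines the `quant_predicate` contract: it receives the layer path, the module, and the model config, and returns either a bool or a dict of parameters for `to_quantized`. A minimal sketch of such a callable, assuming hypothetical layer-name patterns and parameter choices:

import mlx.nn as nn

def mixed_quant_predicate(path: str, module: nn.Module, config: dict):
    # Modules without to_quantized cannot be quantized at all.
    if not hasattr(module, "to_quantized"):
        return False
    # Keep embeddings and the output head at higher precision (illustrative).
    if "embed_tokens" in path or "lm_head" in path:
        return {"group_size": 64, "bits": 8}
    # Everything else falls back to the global q_group_size / q_bits.
    return True

Dict returns are copied into the saved config under the layer path, which is exactly what load_model's class_predicate replays when the quantized model is loaded back.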
@@ -726,6 +754,9 @@ def convert(
     upload_repo: str = None,
     revision: Optional[str] = None,
     dequantize: bool = False,
+    quant_predicate: Optional[
+        Callable[[str, nn.Module, dict], Union[bool, dict]]
+    ] = None,
 ):
     # Check the save path is empty
     if isinstance(mlx_path, str):
@@ -751,7 +782,9 @@ def convert(
     if quantize:
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
-        weights, config = quantize_model(model, config, q_group_size, q_bits)
+        weights, config = quantize_model(
+            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
+        )
 
     if dequantize:
         print("[INFO] Dequantizing")
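A possible call site for the new argument, assuming the `mixed_quant_predicate` sketched earlier; the repo id and output directory are placeholders, not from the commit:

from mlx_lm import convert

convert(
    "some-org/some-model",        # placeholder Hugging Face repo
    mlx_path="mlx_model",         # output directory
    quantize=True,
    q_group_size=64,
    q_bits=4,
    quant_predicate=mixed_quant_predicate,
)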