Add optional quantization types
commit bc08025f41
parent 845efddc8c
@@ -29,6 +29,12 @@ def configure_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--q-bits", help="Bits per weight for quantization.", type=int, default=4
     )
+    parser.add_argument(
+        "--q-type",
+        choices=["affine", "affine-packed"],
+        default="affine",
+        help="The type of quantization to apply",
+    )
     parser.add_argument(
         "--dtype",
         help="Type to save the non-quantized parameters.",
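The first hunk registers the new `--q-type` flag next to `--q-bits`. A minimal standalone sketch of how the flag behaves once parsed (the two `add_argument` calls mirror the hunk; the surrounding parser setup and the sample argv are illustrative):

```python
import argparse

# Stand-in parser; in the repo these calls live in configure_parser().
parser = argparse.ArgumentParser(description="convert")
parser.add_argument(
    "--q-bits", help="Bits per weight for quantization.", type=int, default=4
)
parser.add_argument(
    "--q-type",
    choices=["affine", "affine-packed"],
    default="affine",
    help="The type of quantization to apply",
)

args = parser.parse_args(["--q-type", "affine-packed"])
print(args.q_type)  # affine-packed
print(args.q_bits)  # 4 (default)

# Values outside `choices` are rejected with a usage error:
# parser.parse_args(["--q-type", "int8"])  # SystemExit
```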
@@ -528,6 +528,7 @@ def load_model(
             model,
             group_size=quantization["group_size"],
             bits=quantization["bits"],
+            quantization_type=quantization["quantization_type"],
             class_predicate=class_predicate,
         )
 
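The load path now forwards the saved `quantization_type` to the quantizer alongside `group_size` and `bits`. A small sketch of the config block this indexes into (key names from the hunk; the `.get` fallback is an assumption for configs written before this change — the committed code indexes the key directly):

```python
# Shape of the "quantization" block the loader reads:
quantization = {"group_size": 64, "bits": 4, "quantization_type": "affine"}

# A config saved by an older version of the tool would lack the new key,
# so a defensive reader might fall back to the previous default:
q_type = quantization.get("quantization_type", "affine")
print(q_type)  # affine
```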
@@ -737,6 +738,7 @@ def quantize_model(
     config: dict,
     q_group_size: int,
     q_bits: int,
+    q_type: str,
     quant_predicate: Optional[
         Callable[[str, nn.Module, dict], Union[bool, dict]]
     ] = None,
@@ -749,6 +751,7 @@ def quantize_model(
         config (dict): Model configuration.
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.
+        q_type (str): Quantization type.
         quant_predicate (Callable): A callable that decides how
             to quantize each layer based on the path.
             Accepts the layer `path`, the `module` and the model `config`.
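Note that `q_type` is inserted ahead of `quant_predicate`, so purely positional callers shift by one argument while keyword callers are unaffected. A toy stand-in for the new signature illustrating the difference (the function body is invented for the demonstration):

```python
def quantize_model_stub(model, config, q_group_size, q_bits, q_type,
                        quant_predicate=None):
    # Toy stand-in: just report how the arguments were bound.
    return q_type, quant_predicate

my_predicate = lambda p, m, config: True

# Keyword style, as the updated convert() below uses, is unaffected:
print(quantize_model_stub(None, {}, 64, 4, q_type="affine",
                          quant_predicate=my_predicate))
# A legacy positional call now binds the predicate to q_type instead:
print(quantize_model_stub(None, {}, 64, 4, my_predicate))
```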
@@ -759,11 +762,25 @@ def quantize_model(
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+    quantized_config["quantization"] = {
+        "group_size": q_group_size,
+        "bits": q_bits,
+        "quantization_type": q_type,
+    }
 
     # Add any custom quantization parameters to the config as we go
     def _class_predicate(p, m):
-        bool_or_params = quant_predicate(p, m, config)
+        if quant_predicate:
+            bool_or_params = quant_predicate(p, m, config)
+        else:
+            if isinstance(m, nn.Embedding):
+                bool_or_params = {
+                    "group_size": q_group_size,
+                    "bits": q_bits,
+                    "quantization_type": "affine",
+                }
+            else:
+                bool_or_params = hasattr(m, "to_quantized")
         quantized_config["quantization"][p] = bool_or_params
         return bool_or_params
 
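The new default branch pins embedding layers to plain affine quantization even when the model-wide type is `affine-packed`, and skips any module without a `to_quantized` hook. A toy re-creation of that decision logic outside MLX (`FakeEmbedding` and `FakeLinear` are hypothetical stand-ins for `nn.Embedding` and a quantizable layer):

```python
class FakeEmbedding:
    """Stand-in for nn.Embedding."""

class FakeLinear:
    """Stand-in for a layer that supports quantization."""
    def to_quantized(self):
        pass

def default_choice(m, q_group_size=64, q_bits=4):
    # Mirrors the hunk's else-branch: embeddings get explicit affine
    # parameters, other layers are quantized iff they expose to_quantized.
    if isinstance(m, FakeEmbedding):
        return {"group_size": q_group_size, "bits": q_bits,
                "quantization_type": "affine"}
    return hasattr(m, "to_quantized")

print(default_choice(FakeEmbedding()))  # per-layer affine override
print(default_choice(FakeLinear()))     # True  -> quantize with defaults
print(default_choice(object()))         # False -> leave unquantized
```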
@@ -771,7 +788,8 @@ def quantize_model(
         model,
         q_group_size,
         q_bits,
-        class_predicate=_class_predicate if quant_predicate else None,
+        quantization_type=q_type,
+        class_predicate=_class_predicate,
     )
     # support hf model tree #957
     quantized_config["quantization_config"] = quantized_config["quantization"]
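One detail of the #957 compatibility line retained below the call: the assignment aliases the same dict rather than copying it, so the per-layer entries `_class_predicate` records during quantization are visible under both keys. A quick illustration with made-up values (the layer path is hypothetical):

```python
# Both keys end up pointing at the same dict object, not a copy:
cfg = {
    "quantization": {
        "group_size": 64,
        "bits": 4,
        "quantization_type": "affine-packed",
        "model.embed_tokens": {"group_size": 64, "bits": 4,
                               "quantization_type": "affine"},
    }
}
cfg["quantization_config"] = cfg["quantization"]
print(cfg["quantization_config"]["model.embed_tokens"]["quantization_type"])
# -> affine
print(cfg["quantization_config"] is cfg["quantization"])  # -> True
```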
@@ -812,6 +830,7 @@ def convert(
     quantize: bool = False,
     q_group_size: int = 64,
     q_bits: int = 4,
+    q_type: str = "affine",
     dtype: str = "float16",
     upload_repo: str = None,
     revision: Optional[str] = None,
@@ -845,7 +864,12 @@ def convert(
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
         weights, config = quantize_model(
-            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
+            model,
+            config,
+            q_group_size,
+            q_bits,
+            q_type=q_type,
+            quant_predicate=quant_predicate,
         )
 
         if dequantize:
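Taken together, a hedged end-to-end sketch of driving the new option. The CLI flags come straight from the parser hunk; the Python import path is an assumption about current mlx-lm packaging, and the model path is a placeholder:

```python
# CLI (flags per the parser hunk; -q enables quantization):
#   python -m mlx_lm.convert --hf-path <hf-model> -q --q-bits 4 --q-type affine-packed

# Python API sketch; `from mlx_lm import convert` is assumed, and
# "<hf-model>" is a placeholder Hugging Face repo id.
from mlx_lm import convert

convert(
    "<hf-model>",
    quantize=True,
    q_group_size=64,
    q_bits=4,
    q_type="affine-packed",  # new in this commit; the default stays "affine"
)
```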