Add optional quantization types

Angelos Katharopoulos 2024-12-17 22:24:41 -08:00
parent 845efddc8c
commit bc08025f41
2 changed files with 34 additions and 4 deletions


@@ -29,6 +29,12 @@ def configure_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--q-bits", help="Bits per weight for quantization.", type=int, default=4
     )
+    parser.add_argument(
+        "--q-type",
+        choices=["affine", "affine-packed"],
+        default="affine",
+        help="The type of quantization to apply",
+    )
     parser.add_argument(
         "--dtype",
         help="Type to save the non-quantized parameters.",

@@ -528,6 +528,7 @@ def load_model(
             model,
             group_size=quantization["group_size"],
             bits=quantization["bits"],
+            quantization_type=quantization["quantization_type"],
             class_predicate=class_predicate,
         )
@@ -737,6 +738,7 @@ def quantize_model(
     config: dict,
     q_group_size: int,
     q_bits: int,
+    q_type: str,
     quant_predicate: Optional[
         Callable[[str, nn.Module, dict], Union[bool, dict]]
     ] = None,
@@ -749,6 +751,7 @@ def quantize_model(
         config (dict): Model configuration.
         q_group_size (int): Group size for quantization.
         q_bits (int): Bits per weight for quantization.
+        q_type (str): Quantization type
         quant_predicate (Callable): A callable that decides how
             to quantize each layer based on the path.
             Accepts the layer `path`, the `module` and the model `config`.
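To make the docstring's contract concrete, here is a hedged sketch of a custom predicate; the layer paths are illustrative, not taken from any particular model:

    def my_quant_predicate(path, module, config):
        # Skip a layer entirely (path is hypothetical).
        if path == "lm_head":
            return False
        # Override the settings for one layer with a params dict.
        if path == "model.embed_tokens":
            return {"group_size": 64, "bits": 8, "quantization_type": "affine"}
        # Otherwise defer to the layer's own quantization support.
        return hasattr(module, "to_quantized")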
@@ -759,11 +762,25 @@ def quantize_model(
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+    quantized_config["quantization"] = {
+        "group_size": q_group_size,
+        "bits": q_bits,
+        "quantization_type": q_type,
+    }
 
     # Add any custom quantization parameters to the config as we go
     def _class_predicate(p, m):
-        bool_or_params = quant_predicate(p, m, config)
+        if quant_predicate:
+            bool_or_params = quant_predicate(p, m, config)
+        else:
+            if isinstance(m, nn.Embedding):
+                bool_or_params = {
+                    "group_size": q_group_size,
+                    "bits": q_bits,
+                    "quantization_type": "affine",
+                }
+            else:
+                bool_or_params = hasattr(m, "to_quantized")
         quantized_config["quantization"][p] = bool_or_params
         return bool_or_params
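Two details worth noting in this hunk: the default predicate now always runs (the `class_predicate=... if quant_predicate else None` fallback disappears below), and `nn.Embedding` layers are pinned to plain "affine" even when `q_type` is "affine-packed". An illustrative shape of the resulting quantized_config["quantization"], with hypothetical layer paths:

    {
        "group_size": 64,
        "bits": 4,
        "quantization_type": "affine-packed",
        # Per-layer entries recorded by _class_predicate:
        "model.embed_tokens": {"group_size": 64, "bits": 4, "quantization_type": "affine"},
        "model.layers.0.self_attn.q_proj": True,
    }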
@@ -771,7 +788,8 @@ def quantize_model(
         model,
         q_group_size,
         q_bits,
-        class_predicate=_class_predicate if quant_predicate else None,
+        quantization_type=q_type,
+        class_predicate=_class_predicate,
     )
     # support hf model tree #957
     quantized_config["quantization_config"] = quantized_config["quantization"]
@@ -812,6 +830,7 @@ def convert(
     quantize: bool = False,
     q_group_size: int = 64,
     q_bits: int = 4,
+    q_type: str = "affine",
     dtype: str = "float16",
     upload_repo: str = None,
     revision: Optional[str] = None,
@@ -845,7 +864,12 @@ def convert(
         print("[INFO] Quantizing")
         model.load_weights(list(weights.items()))
         weights, config = quantize_model(
-            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
+            model,
+            config,
+            q_group_size,
+            q_bits,
+            q_type=q_type,
+            quant_predicate=quant_predicate,
         )
 
     if dequantize:
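Putting it together, a hedged end-to-end sketch of the new parameter from Python; the positional model path is an assumption about convert's other arguments, which this diff does not show:

    # Hypothetical call into convert after this change.
    convert(
        "some-org/some-model",   # source model path (illustrative)
        quantize=True,
        q_group_size=64,
        q_bits=4,
        q_type="affine-packed",  # new: choose the quantization type
    )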