2024-02-28 00:47:56 +08:00
|
|
|
# Copyright © 2023-2024 Apple Inc.
|
2024-01-04 07:13:26 +08:00
|
|
|
|
2024-02-28 00:47:56 +08:00
|
|
|
import argparse
|
2025-02-27 03:32:36 +08:00
|
|
|
from enum import Enum
|
2024-01-13 02:25:56 +08:00
|
|
|
|
2025-02-27 03:32:36 +08:00
|
|
|
from .utils import convert, mixed_2_6, mixed_3_6
|
|
|
|
|
|
|
|
|
|
|
|
class MixedQuants(Enum):
    """Names of the mixed-bit quantization recipes selectable by name.

    Each member's value is the same string as its name.
    """

    mixed_3_6 = "mixed_3_6"
    mixed_2_6 = "mixed_2_6"

    @classmethod
    def recipe_names(cls):
        """Return the member names as a list, in definition order."""
        # No two members share a value, so there are no aliases and the
        # ``__members__`` keys are exactly the names produced by iterating
        # the enum itself.
        return list(cls.__members__)
|
|
|
|
|
|
|
|
|
|
|
|
def quant_args(arg):
    """
    Validate a mixed-quantization recipe name given on the command line.

    Intended for use as an argparse ``type`` callable.

    Args:
        arg (str): Recipe name to look up among the ``MixedQuants`` members.

    Returns:
        str: The value of the ``MixedQuants`` member named ``arg``.

    Raises:
        argparse.ArgumentTypeError: If ``arg`` is not the name of a
            ``MixedQuants`` member.
    """
    try:
        return MixedQuants[arg].value
    except KeyError:
        # The KeyError adds nothing beyond the message below, so suppress the
        # implicit exception context instead of leaking a confusing
        # "During handling of the above exception..." chain.
        raise argparse.ArgumentTypeError(
            f"Invalid q-recipe {arg!r}. Choose from: {MixedQuants.recipe_names()}"
        ) from None
|
2024-01-12 04:29:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
def configure_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for the conversion script.

    Returns:
        argparse.ArgumentParser: Parser with the model path, quantization,
        dtype, upload and dequantization options registered.
    """
    cli = argparse.ArgumentParser(
        description="Convert Hugging Face model to MLX format"
    )
    add = cli.add_argument

    # Source and destination paths.
    add(
        "--hf-path",
        help="Path to the Hugging Face model.",
        type=str,
    )
    add(
        "--mlx-path",
        help="Path to save the MLX model.",
        type=str,
        default="mlx_model",
    )

    # Quantization controls.
    add(
        "-q",
        "--quantize",
        action="store_true",
        help="Generate a quantized model.",
    )
    add(
        "--q-group-size",
        type=int,
        default=64,
        help="Group size for quantization.",
    )
    add(
        "--q-bits",
        type=int,
        default=4,
        help="Bits per weight for quantization.",
    )
    # Optional by default; ``quant_args`` rejects unknown recipe names.
    add(
        "--quant-predicate",
        type=quant_args,
        help=f"Mixed-bit quantization recipe. Choices: {MixedQuants.recipe_names()}",
    )

    # Output precision for parameters that are not quantized.
    add(
        "--dtype",
        type=str,
        default="float16",
        choices=["float16", "bfloat16", "float32"],
        help="Type to save the non-quantized parameters.",
    )

    # Optional upload target.
    add(
        "--upload-repo",
        type=str,
        default=None,
        help="The Hugging Face repo to upload the model to.",
    )

    # ``store_true`` already defaults to False when the flag is absent.
    add(
        "-d",
        "--dequantize",
        action="store_true",
        help="Dequantize a quantized model.",
    )

    return cli
|
2024-01-04 07:13:26 +08:00
|
|
|
|
|
|
|
|
2024-04-17 07:08:49 +08:00
|
|
|
def main():
    """Parse the command-line options and pass them to ``convert`` as keyword arguments."""
    namespace = configure_parser().parse_args()
    convert(**vars(namespace))
|
2024-04-17 07:08:49 +08:00
|
|
|
|
|
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|