import argparse
import copy
import glob
import json
import shutil
from pathlib import Path
from typing import Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

from .utils import (
    fetch_from_hub,
    get_model_path,
    linear_class_predicate,
    save_weights,
    upload_to_hub,
)


def configure_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    Returns:
        argparse.ArgumentParser: Configured argument parser.
    """
    parser = argparse.ArgumentParser(
        description="Convert Hugging Face model to MLX format"
    )

    parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
    parser.add_argument(
        "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
    )
    parser.add_argument(
        "-q", "--quantize", help="Generate a quantized model.", action="store_true"
    )
    parser.add_argument(
        "--q-group-size", help="Group size for quantization.", type=int, default=64
    )
    parser.add_argument(
        "--q-bits", help="Bits per weight for quantization.", type=int, default=4
    )
    parser.add_argument(
        "--dtype",
        help="Type to save the parameters, ignored if -q is given.",
        type=str,
        choices=["float16", "bfloat16", "float32"],
        default="float16",
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    return parser


def quantize_model(
    model: nn.Module, config: dict, q_group_size: int, q_bits: int
) -> Tuple:
    """
    Applies quantization to the model weights.

    Args:
        model (nn.Module): The model to be quantized.
        config (dict): Model configuration.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.

    Returns:
        Tuple: Tuple containing quantized weights and config.
    """
    quantized_config = copy.deepcopy(config)
    nn.QuantizedLinear.quantize_module(
        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
    )
    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
    quantized_weights = dict(tree_flatten(model.parameters()))

    return quantized_weights, quantized_config


def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    dtype: str = "float16",
    upload_repo: str = None,
):
    print("[INFO] Loading")
    model_path = get_model_path(hf_path)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

    weights = dict(tree_flatten(model.parameters()))
    dtype = mx.float16 if quantize else getattr(mx, dtype)
    weights = {k: v.astype(dtype) for k, v in weights.items()}

    if quantize:
        print("[INFO] Quantizing")
        model.load_weights(list(weights.items()))
        weights, config = quantize_model(model, config, q_group_size, q_bits)

    if isinstance(mlx_path, str):
        mlx_path = Path(mlx_path)

    del model
    save_weights(mlx_path, weights, donate_weights=True)

    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)

    with open(mlx_path / "config.json", "w") as fid:
        json.dump(config, fid, indent=4)

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, hf_path)


if __name__ == "__main__":
    parser = configure_parser()
    args = parser.parse_args()
    convert(**vars(args))
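
# A minimal usage sketch. It assumes this module lives in a package named
# `mlx_lm` (so the relative `.utils` imports above resolve when run with -m);
# the Hugging Face repo id below is only an illustrative example.
#
#   python -m mlx_lm.convert --hf-path mistralai/Mistral-7B-v0.1 --mlx-path mlx_model -q
#
# With the defaults above, this downloads the model, casts the weights to
# float16, quantizes them to 4 bits with a group size of 64, and writes the
# converted weights, tokenizer files, and config.json to ./mlx_model.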