mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-11-01 11:38:08 +08:00
Mlx llm package (#301)
* fix converter * add recursive files * remove gitignore * remove gitignore * add packages properly * read me update * remove dup readme * relative * fix convert * fix community name * fix url * version
This commit is contained in:
205
llms/mlx_lm/convert.py
Normal file
205
llms/mlx_lm/convert.py
Normal file
@@ -0,0 +1,205 @@
|
||||
import argparse
|
||||
import copy
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import transformers
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
from .utils import get_model_path, load
|
||||
|
||||
MAX_FILE_SIZE_GB = 15
|
||||
|
||||
|
||||
def configure_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Configures and returns the argument parser for the script.
|
||||
|
||||
Returns:
|
||||
argparse.ArgumentParser: Configured argument parser.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Hugging Face model to MLX format"
|
||||
)
|
||||
|
||||
parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
|
||||
parser.add_argument(
|
||||
"--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q", "--quantize", help="Generate a quantized model.", action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-group-size", help="Group size for quantization.", type=int, default=64
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-bits", help="Bits per weight for quantization.", type=int, default=4
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="Type to save the parameters, ignored if -q is given.",
|
||||
type=str,
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
default="float16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-repo",
|
||||
help="The Hugging Face repo to upload the model to.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def fetch_from_hub(
|
||||
model_path: str,
|
||||
) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
|
||||
model_path = get_model_path(model_path)
|
||||
|
||||
weight_files = glob.glob(f"{model_path}/*.safetensors")
|
||||
if not weight_files:
|
||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
|
||||
config = transformers.AutoConfig.from_pretrained(model_path)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
return weights, config.to_dict(), tokenizer
|
||||
|
||||
|
||||
def quantize_model(
|
||||
weights: dict, config: dict, hf_path: str, q_group_size: int, q_bits: int
|
||||
) -> tuple:
|
||||
"""
|
||||
Applies quantization to the model weights.
|
||||
|
||||
Args:
|
||||
weights (dict): Model weights.
|
||||
config (dict): Model configuration.
|
||||
hf_path (str): HF model path..
|
||||
q_group_size (int): Group size for quantization.
|
||||
q_bits (int): Bits per weight for quantization.
|
||||
|
||||
Returns:
|
||||
tuple: Tuple containing quantized weights and config.
|
||||
"""
|
||||
quantized_config = copy.deepcopy(config)
|
||||
model, _ = load(hf_path)
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
nn.QuantizedLinear.quantize_module(model, q_group_size, q_bits)
|
||||
quantized_config["quantization"] = {
|
||||
"group_size": q_group_size,
|
||||
"bits": q_bits,
|
||||
}
|
||||
quantized_weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
return quantized_weights, quantized_config
|
||||
|
||||
|
||||
def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
|
||||
"""
|
||||
Splits the weights into smaller shards.
|
||||
|
||||
Args:
|
||||
weights (dict): Model weights.
|
||||
max_file_size_gb (int): Maximum size of each shard in gigabytes.
|
||||
|
||||
Returns:
|
||||
list: List of weight shards.
|
||||
"""
|
||||
max_file_size_bytes = max_file_size_gb << 30
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
estimated_size = v.size * v.dtype.size
|
||||
if shard_size + estimated_size > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
shard, shard_size = {}, 0
|
||||
shard[k] = v
|
||||
shard_size += estimated_size
|
||||
shards.append(shard)
|
||||
return shards
|
||||
|
||||
|
||||
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
||||
"""
|
||||
Uploads the model to Hugging Face hub.
|
||||
|
||||
Args:
|
||||
path (str): Local path to the model.
|
||||
upload_repo (str): Name of the HF repo to upload to.
|
||||
hf_path (str): Path to the original Hugging Face model.
|
||||
"""
|
||||
import os
|
||||
|
||||
from huggingface_hub import HfApi, ModelCard, logging
|
||||
|
||||
card = ModelCard.load(hf_path)
|
||||
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
|
||||
card.text = f"""
|
||||
# {upload_repo}
|
||||
This model was converted to MLX format from [`{hf_path}`]().
|
||||
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
|
||||
## Use with mlx
|
||||
```bash
|
||||
pip install mlx
|
||||
git clone https://github.com/ml-explore/mlx-examples.git
|
||||
cd mlx-examples/llms/hf_llm
|
||||
python generate.py --model {upload_repo} --prompt "My name is"
|
||||
```
|
||||
"""
|
||||
card.save(os.path.join(path, "README.md"))
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=upload_repo, exist_ok=True)
|
||||
api.upload_folder(
|
||||
folder_path=path,
|
||||
repo_id=upload_repo,
|
||||
repo_type="model",
|
||||
)
|
||||
|
||||
|
||||
def convert(
|
||||
hf_path: str,
|
||||
mlx_path: str = "mlx_model",
|
||||
quantize: bool = False,
|
||||
q_group_size: int = 64,
|
||||
q_bits: int = 4,
|
||||
dtype: str = "float16",
|
||||
upload_repo: str = None,
|
||||
):
|
||||
print("[INFO] Loading")
|
||||
weights, config, tokenizer = fetch_from_hub(hf_path)
|
||||
dtype = mx.float16 if quantize else getattr(mx, dtype)
|
||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||
if quantize:
|
||||
print("[INFO] Quantizing")
|
||||
weights, config = quantize_model(weights, config, hf_path, q_group_size, q_bits)
|
||||
|
||||
mlx_path = Path(mlx_path)
|
||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||
shards = make_shards(weights)
|
||||
for i, shard in enumerate(shards):
|
||||
mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard)
|
||||
tokenizer.save_pretrained(mlx_path)
|
||||
with open(mlx_path / "config.json", "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
|
||||
if upload_repo is not None:
|
||||
upload_to_hub(mlx_path, upload_repo, hf_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = configure_parser()
|
||||
args = parser.parse_args()
|
||||
convert(**vars(args))
|
||||
Reference in New Issue
Block a user