Mlx llm package (#301)

* fix converter

* add recursive files

* remove gitignore

* remove gitignore

* add packages properly

* read me update

* remove dup readme

* relative

* fix convert

* fix community name

* fix url

* version
Awni Hannun
2024-01-12 10:25:56 -08:00
committed by GitHub
parent 2b61d9deb6
commit c6440416a2
17 changed files with 270 additions and 388 deletions

llms/mlx_lm/convert.py (normal file, 205 additions)

@@ -0,0 +1,205 @@
import argparse
import copy
import glob
import json
from pathlib import Path
from typing import Dict, Tuple
import mlx.core as mx
import mlx.nn as nn
import transformers
from mlx.utils import tree_flatten
from .utils import get_model_path, load
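
# Maximum size of each saved safetensors shard; used by make_shards below.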
MAX_FILE_SIZE_GB = 15


def configure_parser() -> argparse.ArgumentParser:
"""
Configures and returns the argument parser for the script.
Returns:
argparse.ArgumentParser: Configured argument parser.
"""
parser = argparse.ArgumentParser(
description="Convert Hugging Face model to MLX format"
)
parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
parser.add_argument(
"--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
)
parser.add_argument(
"-q", "--quantize", help="Generate a quantized model.", action="store_true"
)
parser.add_argument(
"--q-group-size", help="Group size for quantization.", type=int, default=64
)
parser.add_argument(
"--q-bits", help="Bits per weight for quantization.", type=int, default=4
)
parser.add_argument(
"--dtype",
help="Type to save the parameters, ignored if -q is given.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="float16",
)
parser.add_argument(
"--upload-repo",
help="The Hugging Face repo to upload the model to.",
type=str,
default=None,
)
return parser


def fetch_from_hub(
model_path: str,
) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
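    # get_model_path resolves the repo name to a local snapshot directory (fetching from the Hub when not already local).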
model_path = get_model_path(model_path)
weight_files = glob.glob(f"{model_path}/*.safetensors")
if not weight_files:
raise FileNotFoundError(f"No safetensors found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
config = transformers.AutoConfig.from_pretrained(model_path)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
return weights, config.to_dict(), tokenizer


def quantize_model(
weights: dict, config: dict, hf_path: str, q_group_size: int, q_bits: int
) -> tuple:
"""
Applies quantization to the model weights.
Args:
weights (dict): Model weights.
config (dict): Model configuration.
        hf_path (str): Path to the Hugging Face model.
q_group_size (int): Group size for quantization.
q_bits (int): Bits per weight for quantization.
Returns:
tuple: Tuple containing quantized weights and config.
"""
quantized_config = copy.deepcopy(config)
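    # Instantiate the model from hf_path, then swap in the (already dtype-cast) weights before quantizing.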
model, _ = load(hf_path)
model.load_weights(list(weights.items()))
nn.QuantizedLinear.quantize_module(model, q_group_size, q_bits)
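    # Record the quantization parameters so the quantized checkpoint can be loaded correctly later.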
quantized_config["quantization"] = {
"group_size": q_group_size,
"bits": q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config


def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
"""
Splits the weights into smaller shards.
Args:
weights (dict): Model weights.
max_file_size_gb (int): Maximum size of each shard in gigabytes.
Returns:
list: List of weight shards.
"""
max_file_size_bytes = max_file_size_gb << 30
shards = []
shard, shard_size = {}, 0
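    # Greedily pack tensors into shards, starting a new shard whenever the size cap would be exceeded.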
for k, v in weights.items():
estimated_size = v.size * v.dtype.size
if shard_size + estimated_size > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shards.append(shard)
return shards


def upload_to_hub(path: str, upload_repo: str, hf_path: str):
"""
Uploads the model to Hugging Face hub.
Args:
path (str): Local path to the model.
upload_repo (str): Name of the HF repo to upload to.
hf_path (str): Path to the original Hugging Face model.
"""
import os
from huggingface_hub import HfApi, ModelCard, logging
card = ModelCard.load(hf_path)
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
card.text = f"""
# {upload_repo}
This model was converted to MLX format from [`{hf_path}`](https://huggingface.co/{hf_path}).
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
## Use with mlx
```bash
pip install mlx
git clone https://github.com/ml-explore/mlx-examples.git
cd mlx-examples/llms/hf_llm
python generate.py --model {upload_repo} --prompt "My name is"
```
"""
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi()
api.create_repo(repo_id=upload_repo, exist_ok=True)
api.upload_folder(
folder_path=path,
repo_id=upload_repo,
repo_type="model",
)


def convert(
hf_path: str,
mlx_path: str = "mlx_model",
quantize: bool = False,
q_group_size: int = 64,
q_bits: int = 4,
dtype: str = "float16",
upload_repo: str = None,
):
print("[INFO] Loading")
weights, config, tokenizer = fetch_from_hub(hf_path)
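    # Quantization always starts from float16 weights; otherwise cast to the requested --dtype.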
dtype = mx.float16 if quantize else getattr(mx, dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()}
if quantize:
print("[INFO] Quantizing")
weights, config = quantize_model(weights, config, hf_path, q_group_size, q_bits)
mlx_path = Path(mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
shards = make_shards(weights)
for i, shard in enumerate(shards):
mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard)
tokenizer.save_pretrained(mlx_path)
with open(mlx_path / "config.json", "w") as fid:
json.dump(config, fid, indent=4)
if upload_repo is not None:
upload_to_hub(mlx_path, upload_repo, hf_path)


if __name__ == "__main__":
parser = configure_parser()
args = parser.parse_args()
convert(**vars(args))
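
For reference, here is a minimal sketch of driving the converter from Python rather than the CLI; the Hugging Face repo name is only an illustrative placeholder, and it assumes the `mlx_lm` package is importable:

```python
# Hypothetical usage sketch (not part of this commit): convert a Hugging Face
# model to MLX format and quantize it to 4 bits using the defaults above.
from mlx_lm.convert import convert

convert(
    hf_path="mistralai/Mistral-7B-v0.1",  # placeholder repo; any safetensors model supported by mlx_lm
    mlx_path="mlx_model",
    quantize=True,
)
```

Because `convert.py` uses a relative import (`from .utils import ...`), it is run as a module, e.g. `python -m mlx_lm.convert --hf-path <repo> -q`, rather than as a standalone script.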