refactor(hf_llm): moving phi2 example into hf_llm (#293)

* refactor: moving phi2 example into hf_llm

* chore: clean up

* chore: update phi2 model args so it can load args from config

* fix phi2 + nits + readme

* allow any HF repo, update README

* fix bug in llama

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Anchen authored on 2024-01-11 12:29:12 -08:00, committed by GitHub
parent e74889d0fa
commit a2402116ae
15 changed files with 647 additions and 697 deletions


@@ -1,52 +1,95 @@
# Copyright © 2023 Apple Inc.
import argparse
import copy
import glob
import json
from pathlib import Path
+from typing import Dict, Tuple
import mlx.core as mx
import mlx.nn as nn
import transformers
-from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
-from models import Model, ModelArgs
+from utils import get_model_path, load
+MAX_FILE_SIZE_GB = 15
-def fetch_from_hub(model_path: str, local: bool):
-    if not local:
-        model_path = snapshot_download(
-            repo_id=model_path,
-            allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
-        )
+def configure_parser() -> argparse.ArgumentParser:
+    """
+    Configures and returns the argument parser for the script.
+    Returns:
+        argparse.ArgumentParser: Configured argument parser.
+    """
+    parser = argparse.ArgumentParser(
+        description="Convert Hugging Face model to MLX format"
+    )
+    parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
+    parser.add_argument(
+        "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
+    )
+    parser.add_argument(
+        "-q", "--quantize", help="Generate a quantized model.", action="store_true"
+    )
+    parser.add_argument(
+        "--q-group-size", help="Group size for quantization.", type=int, default=64
+    )
+    parser.add_argument(
+        "--q-bits", help="Bits per weight for quantization.", type=int, default=4
+    )
+    parser.add_argument(
+        "--dtype",
+        help="Type to save the parameters, ignored if -q is given.",
+        type=str,
+        choices=["float16", "bfloat16", "float32"],
+        default="float16",
+    )
+    parser.add_argument(
+        "--upload-repo",
+        help="The Hugging Face repo to upload the model to.",
+        type=str,
+        default=None,
+    )
+    return parser
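As a quick sanity check of the new configure_parser, the flags can be exercised programmatically. A minimal sketch (the Hugging Face repo id below is only an illustrative placeholder, not taken from this commit):

# Sketch: parse the same flags the CLI would receive.
parser = configure_parser()
args = parser.parse_args(["--hf-path", "microsoft/phi-2", "-q"])
print(args.hf_path, args.quantize, args.q_group_size, args.q_bits)  # defaults: group size 64, 4 bits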
+def fetch_from_hub(
+    model_path: str,
+) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
+    model_path = get_model_path(model_path)
    weight_files = glob.glob(f"{model_path}/*.safetensors")
-    if len(weight_files) == 0:
-        raise FileNotFoundError("No safetensors found in {}".format(model_path))
+    if not weight_files:
+        raise FileNotFoundError(f"No safetensors found in {model_path}")
    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf).items())
    config = transformers.AutoConfig.from_pretrained(model_path)
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        model_path,
-    )
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
    return weights, config.to_dict(), tokenizer
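A minimal usage sketch for fetch_from_hub (the repo id is a placeholder; per the commit message, any Hugging Face repo with safetensors weights should work):

# Sketch: resolve the model path (downloading if needed) and load its tensors.
weights, config, tokenizer = fetch_from_hub("microsoft/phi-2")
print(f"loaded {len(weights)} tensors, model_type={config.get('model_type')}")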
-def quantize(weights, config, args):
-    quantized_config = copy.deepcopy(config)
-    # Load the model:
-    model = Model(ModelArgs.from_dict(config))
+def quantize(weights: dict, config: dict, args: argparse.Namespace) -> tuple:
+    """
+    Applies quantization to the model weights.
+    Args:
+        weights (dict): Model weights.
+        config (dict): Model configuration.
+        args (argparse.Namespace): Command-line arguments.
+    Returns:
+        tuple: Tuple containing quantized weights and config.
+    """
+    quantized_config = copy.deepcopy(config)
+    model, _ = load(args.hf_path)
    model.load_weights(list(weights.items()))
    # Quantize the model:
    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
    # Update the config:
    quantized_config["quantization"] = {
        "group_size": args.q_group_size,
        "bits": args.q_bits,
@@ -56,8 +99,18 @@ def quantize(weights, config, args):
    return quantized_weights, quantized_config
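How the pieces fit together: quantize returns both the quantized tensors and a config that records the quantization settings, which a loader needs in order to rebuild the quantized layers. A small sketch, assuming weights and config came from fetch_from_hub and args from configure_parser:

# Sketch: quantize and inspect the recorded settings.
weights, config = quantize(weights, config, args)
print(config["quantization"])  # e.g. {"group_size": 64, "bits": 4} with the default flags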
-def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
-    max_file_size_bytes = max_file_size_gibibyte << 30
+def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
+    """
+    Splits the weights into smaller shards.
+    Args:
+        weights (dict): Model weights.
+        max_file_size_gb (int): Maximum size of each shard in gigabytes.
+    Returns:
+        list: List of weight shards.
+    """
+    max_file_size_bytes = max_file_size_gb << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
@@ -71,17 +124,23 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
    return shards
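The body of the sharding loop falls outside this hunk. As a rough illustration only (not the exact code elided above), the idea is a greedy pass that starts a new shard once adding the next tensor would cross the size limit; it assumes each value exposes an nbytes attribute, as mlx arrays do:

def make_shards_sketch(weights: dict, max_bytes: int) -> list:
    shards, shard, shard_size = [], {}, 0
    for name, tensor in weights.items():
        if shard and shard_size + tensor.nbytes > max_bytes:
            shards.append(shard)  # current shard is full, start a new one
            shard, shard_size = {}, 0
        shard[name] = tensor
        shard_size += tensor.nbytes
    if shard:
        shards.append(shard)
    return shards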
-def upload_to_hub(path: str, name: str, hf_path: str):
+def upload_to_hub(path: str, upload_repo: str, hf_path: str):
+    """
+    Uploads the model to Hugging Face hub.
+    Args:
+        path (str): Local path to the model.
+        upload_repo (str): Name of the HF repo to upload to.
+        hf_path (str): Path to the original Hugging Face model.
+    """
    import os
    from huggingface_hub import HfApi, ModelCard, logging
-    repo_id = f"mlx-community/{name}"
    card = ModelCard.load(hf_path)
    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
    card.text = f"""
-# {name}
+# {upload_repo}
This model was converted to MLX format from [`{hf_path}`]().
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
## Use with mlx
@@ -97,72 +156,20 @@ python generate.py --model {repo_id} --prompt "My name is"
    logging.set_verbosity_info()
    api = HfApi()
-    api.create_repo(repo_id=repo_id, exist_ok=True)
+    api.create_repo(repo_id=upload_repo, exist_ok=True)
    api.upload_folder(
        folder_path=path,
-        repo_id=repo_id,
+        repo_id=upload_repo,
        repo_type="model",
    )
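An illustrative call (both repo names are placeholders; uploading requires being logged in to the Hugging Face Hub with write access to the target repo):

# Sketch: push a converted model folder to the Hub.
upload_to_hub("mlx_model", "mlx-community/phi-2-4bit-mlx", "microsoft/phi-2")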
if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Convert Hugging Face model to MLX format"
-    )
-    parser.add_argument(
-        "--hf-path",
-        type=str,
-        help="Path to the Hugging Face model.",
-    )
-    parser.add_argument(
-        "--mlx-path",
-        type=str,
-        default="mlx_model",
-        help="Path to save the MLX model.",
-    )
-    parser.add_argument(
-        "-q",
-        "--quantize",
-        help="Generate a quantized model.",
-        action="store_true",
-    )
-    parser.add_argument(
-        "--q-group-size",
-        help="Group size for quantization.",
-        type=int,
-        default=64,
-    )
-    parser.add_argument(
-        "--q-bits",
-        help="Bits per weight for quantization.",
-        type=int,
-        default=4,
-    )
-    parser.add_argument(
-        "--dtype",
-        help="Type to save the parameters, ignored if -q is given.",
-        type=str,
-        choices=["float16", "bfloat16", "float32"],
-        default="float16",
-    )
-    parser.add_argument(
-        "--upload-name",
-        help="The name of model to upload to Hugging Face MLX Community",
-        type=str,
-        default=None,
-    )
-    parser.add_argument(
-        "-l",
-        "--local",
-        action="store_true",
-        help="Whether the hf-path points to a local filesystem.",
-        default=False,
-    )
+    parser = configure_parser()
    args = parser.parse_args()
    print("[INFO] Loading")
-    weights, config, tokenizer = fetch_from_hub(args.hf_path, args.local)
+    weights, config, tokenizer = fetch_from_hub(args.hf_path)
    dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
    weights = {k: v.astype(dtype) for k, v in weights.items()}
@@ -179,5 +186,5 @@ if __name__ == "__main__":
    with open(mlx_path / "config.json", "w") as fid:
        json.dump(config, fid, indent=4)
-    if args.upload_name is not None and not args.local:
-        upload_to_hub(mlx_path, args.upload_name, args.hf_path)
+    if args.upload_repo is not None:
+        upload_to_hub(mlx_path, args.upload_repo, args.hf_path)
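Putting it together, the __main__ block above roughly amounts to the following condensed sketch (the step that writes each shard to disk is elided by the diff, so it is only indicated with a comment here; paths and repo ids are placeholders):

args = configure_parser().parse_args(
    ["--hf-path", "microsoft/phi-2", "--mlx-path", "mlx_model", "-q"]
)
weights, config, tokenizer = fetch_from_hub(args.hf_path)
weights, config = quantize(weights, config, args)
for shard in make_shards(weights):
    ...  # each shard would be written into args.mlx_path as a .safetensors file (elided above)
# config.json is then written next to the shards, as shown in the last hunk.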