mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
refactor(hf_llm): moving phi2 example into hf_llm (#293)
* refactor: moving phi2 example into hf_llm * chore: clean up * chore: update phi2 model args so it can load args from config * fix phi2 + nits + readme * allow any HF repo, update README * fix bug in llama --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1,52 +1,95 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import transformers
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_flatten
|
||||
from models import Model, ModelArgs
|
||||
from utils import get_model_path, load
|
||||
|
||||
MAX_FILE_SIZE_GB = 15
|
||||
|
||||
|
||||
def fetch_from_hub(model_path: str, local: bool):
|
||||
if not local:
|
||||
model_path = snapshot_download(
|
||||
repo_id=model_path,
|
||||
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
||||
)
|
||||
def configure_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Configures and returns the argument parser for the script.
|
||||
|
||||
Returns:
|
||||
argparse.ArgumentParser: Configured argument parser.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Hugging Face model to MLX format"
|
||||
)
|
||||
|
||||
parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
|
||||
parser.add_argument(
|
||||
"--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q", "--quantize", help="Generate a quantized model.", action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-group-size", help="Group size for quantization.", type=int, default=64
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-bits", help="Bits per weight for quantization.", type=int, default=4
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="Type to save the parameters, ignored if -q is given.",
|
||||
type=str,
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
default="float16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-repo",
|
||||
help="The Hugging Face repo to upload the model to.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def fetch_from_hub(
|
||||
model_path: str,
|
||||
) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
|
||||
model_path = get_model_path(model_path)
|
||||
|
||||
weight_files = glob.glob(f"{model_path}/*.safetensors")
|
||||
if len(weight_files) == 0:
|
||||
raise FileNotFoundError("No safetensors found in {}".format(model_path))
|
||||
if not weight_files:
|
||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
|
||||
config = transformers.AutoConfig.from_pretrained(model_path)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
return weights, config.to_dict(), tokenizer
|
||||
|
||||
|
||||
def quantize(weights, config, args):
|
||||
quantized_config = copy.deepcopy(config)
|
||||
def quantize(weights: dict, config: dict, args: argparse.Namespace) -> tuple:
|
||||
"""
|
||||
Applies quantization to the model weights.
|
||||
|
||||
# Load the model:
|
||||
model = Model(ModelArgs.from_dict(config))
|
||||
Args:
|
||||
weights (dict): Model weights.
|
||||
config (dict): Model configuration.
|
||||
args (argparse.Namespace): Command-line arguments.
|
||||
|
||||
Returns:
|
||||
tuple: Tuple containing quantized weights and config.
|
||||
"""
|
||||
quantized_config = copy.deepcopy(config)
|
||||
model, _ = load(args.hf_path)
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
# Quantize the model:
|
||||
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
|
||||
|
||||
# Update the config:
|
||||
quantized_config["quantization"] = {
|
||||
"group_size": args.q_group_size,
|
||||
"bits": args.q_bits,
|
||||
@@ -56,8 +99,18 @@ def quantize(weights, config, args):
|
||||
return quantized_weights, quantized_config
|
||||
|
||||
|
||||
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
max_file_size_bytes = max_file_size_gibibyte << 30
|
||||
def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
|
||||
"""
|
||||
Splits the weights into smaller shards.
|
||||
|
||||
Args:
|
||||
weights (dict): Model weights.
|
||||
max_file_size_gb (int): Maximum size of each shard in gigabytes.
|
||||
|
||||
Returns:
|
||||
list: List of weight shards.
|
||||
"""
|
||||
max_file_size_bytes = max_file_size_gb << 30
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
for k, v in weights.items():
|
||||
@@ -71,17 +124,23 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
return shards
|
||||
|
||||
|
||||
def upload_to_hub(path: str, name: str, hf_path: str):
|
||||
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
||||
"""
|
||||
Uploads the model to Hugging Face hub.
|
||||
|
||||
Args:
|
||||
path (str): Local path to the model.
|
||||
upload_repo (str): Name of the HF repo to upload to.
|
||||
hf_path (str): Path to the original Hugging Face model.
|
||||
"""
|
||||
import os
|
||||
|
||||
from huggingface_hub import HfApi, ModelCard, logging
|
||||
|
||||
repo_id = f"mlx-community/{name}"
|
||||
|
||||
card = ModelCard.load(hf_path)
|
||||
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
|
||||
card.text = f"""
|
||||
# {name}
|
||||
# {upload_repo}
|
||||
This model was converted to MLX format from [`{hf_path}`]().
|
||||
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
|
||||
## Use with mlx
|
||||
@@ -97,72 +156,20 @@ python generate.py --model {repo_id} --prompt "My name is"
|
||||
logging.set_verbosity_info()
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=repo_id, exist_ok=True)
|
||||
api.create_repo(repo_id=upload_repo, exist_ok=True)
|
||||
api.upload_folder(
|
||||
folder_path=path,
|
||||
repo_id=repo_id,
|
||||
repo_id=upload_repo,
|
||||
repo_type="model",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Hugging Face model to MLX format"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-path",
|
||||
type=str,
|
||||
help="Path to the Hugging Face model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx-path",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="Path to save the MLX model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quantize",
|
||||
help="Generate a quantized model.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-group-size",
|
||||
help="Group size for quantization.",
|
||||
type=int,
|
||||
default=64,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-bits",
|
||||
help="Bits per weight for quantization.",
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="Type to save the parameters, ignored if -q is given.",
|
||||
type=str,
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
default="float16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-name",
|
||||
help="The name of model to upload to Hugging Face MLX Community",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--local",
|
||||
action="store_true",
|
||||
help="Whether the hf-path points to a local filesystem.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
parser = configure_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
print("[INFO] Loading")
|
||||
weights, config, tokenizer = fetch_from_hub(args.hf_path, args.local)
|
||||
weights, config, tokenizer = fetch_from_hub(args.hf_path)
|
||||
|
||||
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
|
||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||
@@ -179,5 +186,5 @@ if __name__ == "__main__":
|
||||
with open(mlx_path / "config.json", "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
|
||||
if args.upload_name is not None and not args.local:
|
||||
upload_to_hub(mlx_path, args.upload_name, args.hf_path)
|
||||
if args.upload_repo is not None:
|
||||
upload_to_hub(mlx_path, args.upload_repo, args.hf_path)
|
||||
|
||||
Reference in New Issue
Block a user