Mirror of https://github.com/ml-explore/mlx-examples.git
Synced 2025-06-24 09:21:18 +08:00

* refactor: moving phi2 example into hf_llm
* chore: clean up
* chore: update phi2 model args so it can load args from config
* fix phi2 + nits + readme
* allow any HF repo, update README
* fix bug in llama

Co-authored-by: Awni Hannun <awni@apple.com>
142 lines
3.9 KiB
Python
import glob
import json
import logging
from pathlib import Path
from typing import Generator, Tuple

import mlx.core as mx
import mlx.nn as nn

# Local imports
import models.llama as llama
import models.phi2 as phi2
from huggingface_hub import snapshot_download
from models.base import BaseModelArgs
from transformers import AutoTokenizer, PreTrainedTokenizer

# Constants
MODEL_MAPPING = {
    "llama": llama,
    "mistral": llama,  # mistral is compatible with llama
    "phi": phi2,
}
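# Each entry maps a config.json "model_type" to a local module that exposes a
# `Model` class and a `ModelArgs` dataclass; supporting another architecture
# means importing its module above and registering it here (hypothetical
# example: "qwen": qwen).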

def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
    model_type = config["model_type"]
    if model_type not in MODEL_MAPPING:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        raise ValueError(msg)

    arch = MODEL_MAPPING[model_type]
    return arch.Model, arch.ModelArgs

def get_model_path(path_or_hf_repo: str) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.

    Returns:
        Path: The path to the model.
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        model_path = Path(
            snapshot_download(
                repo_id=path_or_hf_repo,
                allow_patterns=["*.json", "*.safetensors", "*.py", "tokenizer.model"],
            )
        )
    return model_path
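# get_model_path defers to huggingface_hub's snapshot_download, which caches
# the repo and returns the local snapshot directory; allow_patterns restricts
# the download to the config, safetensors weights, tokenizer files, and any
# bundled .py files.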

def generate(
    prompt: mx.array, model: nn.Module, temp: float = 0.0
) -> Generator[mx.array, None, None]:
    """
    Generate text based on the given prompt and model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        temp (float): The temperature for sampling. If temp is 0, use greedy
            (argmax) sampling.

    Yields:
        mx.array: The next generated token.
    """

    def sample(logits: mx.array) -> mx.array:
        return (
            mx.argmax(logits, axis=-1)
            if temp == 0
            else mx.random.categorical(logits * (1 / temp))
        )

    y = prompt
    cache = None
    while True:
        logits, cache = model(y[None], cache=cache)
        logits = logits[:, -1, :]
        y = sample(logits)
        yield y
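# Note that generate never stops on its own: it yields one token id per step,
# and the caller decides when to break, e.g. on the tokenizer's EOS token or a
# maximum token count (see the usage sketch at the end of this file).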

def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
    """
    Load the model from a given path or a Hugging Face repository.

    Args:
        path_or_hf_repo (str): The path or the Hugging Face repository to load the model from.

    Returns:
        Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.

    Raises:
        FileNotFoundError: If the config file or the safetensors weights are not found.
        ValueError: If the model type in the config is not supported.
    """
    model_path = get_model_path(path_or_hf_repo)

    try:
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)
            quantization = config.get("quantization", None)
    except FileNotFoundError:
        logging.error(f"Config file not found in {model_path}")
        raise
    weight_files = glob.glob(str(model_path / "*.safetensors"))
    if not weight_files:
        logging.error(f"No safetensors found in {model_path}")
        raise FileNotFoundError(f"No safetensors found in {model_path}")
    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))

    model_class, model_args_class = _get_classes(config=config)

    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)

    if quantization is not None:
        nn.QuantizedLinear.quantize_module(model, **quantization)

    model.load_weights(list(weights.items()))

    mx.eval(model.parameters())
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer
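Taken together, load and generate cover the whole inference path: resolve the
model files, build and optionally quantize the model, then sample one token at
a time. The sketch below shows one way to drive them; it reuses the imports
above, assumes the tokenizer defines an EOS token, and the repository ID,
prompt, temperature, and 100-token limit are illustrative placeholders rather
than part of this file.

if __name__ == "__main__":
    model, tokenizer = load("mistralai/Mistral-7B-v0.1")  # placeholder repo ID
    prompt = mx.array(tokenizer.encode("Write a haiku about the ocean."))

    tokens = []
    # Stop at the EOS token or after at most 100 generated tokens.
    for token, _ in zip(generate(prompt, model, temp=0.7), range(100)):
        if token.item() == tokenizer.eos_token_id:
            break
        tokens.append(token.item())

    print(tokenizer.decode(tokens))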