mlx-examples/llms/mlx_lm/utils.py
someone 2287294723
fix mlx_lm generator for chinese (#321)
* fix generator for chinese

* add REPLACEMENT_CHAR

---------

Co-authored-by: cg <cg@qq.com>
2024-01-16 07:13:33 -08:00

193 lines
5.4 KiB
Python

import glob
import json
import logging
from pathlib import Path
from typing import Generator, Tuple
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, PreTrainedTokenizer
# Local imports
from .models import llama, mixtral, phi2
from .models.base import BaseModelArgs
# Constants
# Lookup table from a config's "model_type" value to the module that
# implements that architecture (each module exposes Model and ModelArgs).
MODEL_MAPPING = {
    "llama": llama,
    "mistral": llama,  # mistral is compatible with llama
    "mixtral": mixtral,
    "phi": phi2,
}
def linear_class_predicate(m: nn.Module) -> bool:
    """
    Tell whether a module is a linear layer that can be quantized.

    Args:
        m (nn.Module): The candidate module.

    Returns:
        bool: True only when ``m`` is an ``nn.Linear`` whose weight has a
        first dimension that is a multiple of 32.
    """
    # TODO remove this once we support quantization for non-multiples of 32
    return isinstance(m, nn.Linear) and m.weight.shape[0] % 32 == 0
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
Args:
config (dict): The model configuration.
Returns:
A tuple containing the Model class and the ModelArgs class.
"""
model_type = config["model_type"]
if model_type not in MODEL_MAPPING:
msg = f"Model type {model_type} not supported."
logging.error(msg)
raise ValueError(msg)
arch = MODEL_MAPPING[model_type]
return arch.Model, arch.ModelArgs
def get_model_path(path_or_hf_repo: str) -> Path:
    """
    Resolve a model location to a local directory.

    If ``path_or_hf_repo`` names an existing local path it is returned as-is;
    otherwise it is treated as a Hugging Face repository ID and a snapshot is
    downloaded.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.

    Returns:
        Path: The path to the model.
    """
    local_path = Path(path_or_hf_repo)
    if local_path.exists():
        return local_path

    # Not on disk: fetch only the config, weight, code and tokenizer files.
    snapshot_dir = snapshot_download(
        repo_id=path_or_hf_repo,
        allow_patterns=["*.json", "*.safetensors", "*.py", "tokenizer.model"],
    )
    return Path(snapshot_dir)
def generate_step(
    prompt: mx.array, model: nn.Module, temp: float = 0.0
) -> Generator[mx.array, None, None]:
    """
    An endless generator of sampled tokens, seeded with the given prompt.

    The model is first called on the whole prompt; every later call feeds
    back only the previously sampled token together with the cache returned
    by the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        temp (float): The temperature for sampling. If temp is 0, use max sampling.

    Yields:
        mx.array: The token sampled at each step. The generator never stops
        on its own; the caller decides when to stop iterating.
    """

    def sample(logits: mx.array) -> mx.array:
        # Greedy pick at zero temperature, otherwise temperature-scaled sampling.
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        return mx.random.categorical(logits * (1 / temp))

    step_input, state = prompt, None
    while True:
        outputs, state = model(step_input[None], cache=state)
        # Only the logits of the last position are used to pick the next token.
        step_input = sample(outputs[:, -1, :])
        yield step_input
def generate(
    model: nn.Module,
    tokenizer: PreTrainedTokenizer,
    prompt: str,
    temp: float = 0.0,
    max_tokens: int = 100,
    verbose: bool = False,
) -> str:
    """
    Generate text from the model.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        temp (float): The temperature for sampling (default 0).
        max_tokens (int): The maximum number of tokens (default 100).
        verbose (bool): If True, print the text to stdout while it is being
            generated (default False).

    Returns:
        str: The decoded generated text, with any U+FFFD replacement
        characters removed.
    """
    REPLACEMENT_CHAR = "\ufffd"

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    token_ids = []
    skip = 0  # number of characters of the decoded text already printed

    # `range` is listed first on purpose: zip pulls from its arguments left to
    # right, so once the token budget is exhausted it stops without advancing
    # the generator for a token that would be thrown away.
    for _, token in zip(
        range(max_tokens), generate_step(prompt_tokens, model, temp)
    ):
        token_id = token.item()
        if token_id == tokenizer.eos_token_id:
            break
        token_ids.append(token_id)

        if verbose:
            s = tokenizer.decode(token_ids)
            # Hold back output while the decoded text contains U+FFFD
            # (presumably a multi-byte character split across tokens);
            # the pending text is printed once a clean decode is available.
            if REPLACEMENT_CHAR not in s:
                print(s[skip:], end="", flush=True)
                skip = len(s)

    text = tokenizer.decode(token_ids).replace(REPLACEMENT_CHAR, "")
    if verbose:
        # Flush whatever was held back and terminate the line.
        print(text[skip:], flush=True)
    return text
def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]:
    """
    Build a model and its tokenizer from a local path or a huggingface repository.

    Args:
        path_or_hf_repo (str): The path or the huggingface repository to load the model from.

    Returns:
        Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """
    model_path = get_model_path(path_or_hf_repo)

    # Read the model configuration.
    config_file = model_path / "config.json"
    try:
        with open(config_file, "r") as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error(f"Config file not found in {model_path}")
        raise
    quantization = config.get("quantization", None)

    # Gather every safetensors shard into a single name -> array mapping.
    shard_paths = glob.glob(str(model_path / "*.safetensors"))
    if len(shard_paths) == 0:
        logging.error(f"No safetensors found in {model_path}")
        raise FileNotFoundError(f"No safetensors found in {model_path}")
    weights = {
        name: array
        for shard in shard_paths
        for name, array in mx.load(shard).items()
    }

    # Instantiate the architecture selected by the config.
    model_class, model_args_class = _get_classes(config=config)
    model = model_class(model_args_class.from_dict(config))

    # Quantize the module structure before loading weights when the config
    # carries a "quantization" section.
    if quantization is not None:
        nn.QuantizedLinear.quantize_module(
            model,
            **quantization,
            linear_class_predicate=linear_class_predicate,
        )

    model.load_weights(list(weights.items()))
    mx.eval(model.parameters())

    return model, AutoTokenizer.from_pretrained(model_path)