mlx-examples/llms/mlx_lm/utils.py

import glob
import json
import logging
from pathlib import Path
from typing import Generator, Tuple

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, PreTrainedTokenizer

# Local imports
from .models import llama, mixtral, phi2, qwen

# Constants
MODEL_MAPPING = {
    "llama": llama,
    "mistral": llama,  # mistral is compatible with llama
    "mixtral": mixtral,
    "phi": phi2,
    "qwen": qwen,
}

linear_class_predicate = (
    lambda m: isinstance(m, nn.Linear)
    and m.weight.shape[0]
    != 8  # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
)
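
# Note: Mixtral's MoE gate is an nn.Linear whose output dimension equals its
# number of experts (8), so the shape check above is a heuristic that leaves
# the router unquantized.
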
def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
model_type = config["model_type"]
if model_type not in MODEL_MAPPING:
msg = f"Model type {model_type} not supported."
logging.error(msg)
raise ValueError(msg)
arch = MODEL_MAPPING[model_type]
return arch.Model, arch.ModelArgs


def get_model_path(path_or_hf_repo: str) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.

    Returns:
        Path: The path to the model.
    """
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
],
)
)
return model_path
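

# Usage sketch (the repository id below is illustrative): resolve a model
# directory, either a local folder or a freshly downloaded snapshot:
#
#   model_path = get_model_path("mlx-community/Mistral-7B-Instruct-v0.2-4bit")
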
def generate_step(
    prompt: mx.array,
    model: nn.Module,
    temp: float = 0.0,
    return_probability: bool = False,
) -> Generator[Tuple[mx.array, float], None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        temp (float): The temperature for sampling. If temp is 0, use greedy
            (argmax) sampling.
        return_probability (bool): Whether to also compute the probability of
            each generated token. When False, a placeholder of 1.0 is yielded.

    Yields:
        Tuple[mx.array, float]: One (token, probability) pair per call.
    """

    def sample(logits: mx.array) -> Tuple[mx.array, float]:
        prob = 1.0
        if temp == 0:
            token = mx.argmax(logits, axis=-1)
            if return_probability:
                # With greedy sampling there is no temperature scaling;
                # dividing the logits by temp == 0 would be undefined, so
                # report the softmax probability of the raw logits.
                prob = mx.softmax(logits, axis=-1)[0, token.item()]
        else:
            token = mx.random.categorical(logits * (1 / temp))
            if return_probability:
                prob = mx.softmax(logits * (1 / temp), axis=-1)[0, token.item()]
        return token, prob

    y = prompt
    cache = None
    while True:
        logits, cache = model(y[None], cache=cache)
        logits = logits[:, -1, :]
        y, prob = sample(logits)
        yield y, prob
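

# Consumption sketch (illustrative): a caller that wants colorized output can
# request per-token probabilities and pick a color from each one, e.g.
#
#   for (token, prob), _ in zip(
#       generate_step(prompt, model, temp, return_probability=True),
#       range(max_tokens),
#   ):
#       ...  # choose a color based on prob before printing the token
#
# `generate` below simply discards the probability.
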
def generate(
model: nn.Module,
tokenizer: PreTrainedTokenizer,
prompt: str,
temp: float = 0.0,
max_tokens: int = 100,
verbose: bool = False,
) -> str:
"""
Generate text from the model.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
temp (float): The temperature for sampling (default 0).
max_tokens (int): The maximum number of tokens (default 100).
"""
prompt = mx.array(tokenizer.encode(prompt))
tokens = []
skip = 0
REPLACEMENT_CHAR = "\ufffd"
    for (token, _), _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
if token == tokenizer.eos_token_id:
break
tokens.append(token.item())
if verbose:
s = tokenizer.decode(tokens)
if REPLACEMENT_CHAR not in s:
print(s[skip:], end="", flush=True)
skip = len(s)
tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
if verbose:
print(tokens[skip:], flush=True)
return tokens


def load_model(model_path: Path) -> nn.Module:
    """
    Load and initialize the model from a given path.

    Args:
        model_path (Path): The path to load the model from.

    Returns:
        nn.Module: The loaded and initialized model.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If the model class or args class are not found or cannot be instantiated.
    """
try:
with open(model_path / "config.json", "r") as f:
config = json.load(f)
quantization = config.get("quantization", None)
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise
weight_files = glob.glob(str(model_path / "*.safetensors"))
if not weight_files:
logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(wf))
model_class, model_args_class = _get_classes(config=config)
if hasattr(model_class, "sanitize"):
weights = model_class.sanitize(weights)
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
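    # For quantized checkpoints, config.json carries the quantization kwargs
    # (typically something like {"group_size": 64, "bits": 4}), which are
    # forwarded unchanged to QuantizedLinear.quantize_module.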
if quantization is not None:
nn.QuantizedLinear.quantize_module(
model,
**quantization,
linear_class_predicate=linear_class_predicate,
)
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
return model


def load(
    path_or_hf_repo: str, tokenizer_config={}
) -> Tuple[nn.Module, PreTrainedTokenizer]:
    """
    Load the model and tokenizer from a given path or a Hugging Face repository.

    Args:
        path_or_hf_repo (str): The path or the Hugging Face repository to load
            the model from.
        tokenizer_config (dict, optional): Configuration parameters specifically
            for the tokenizer. Defaults to an empty dictionary.

    Returns:
        Tuple[nn.Module, PreTrainedTokenizer]: The loaded model and tokenizer.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """
model_path = get_model_path(path_or_hf_repo)
model = load_model(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
return model, tokenizer
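

if __name__ == "__main__":
    # Minimal end-to-end sketch. The repository id is illustrative; any local
    # path or Hugging Face repo holding an MLX-format checkpoint works.
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")
    print(generate(model, tokenizer, prompt="Hello", temp=0.7, max_tokens=64))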