import copy
import glob
import json
import logging
import time
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

# Local imports
from .models import llama, mixtral, phi2, plamo, qwen, stablelm_epoch
from .tuner.utils import apply_lora_layers

# Constants
MODEL_MAPPING = {
    "llama": llama,
    "mistral": llama,  # mistral is compatible with llama
    "mixtral": mixtral,
    "phi": phi2,
    "stablelm_epoch": stablelm_epoch,
    "qwen": qwen,
    "plamo": plamo,
}
MAX_FILE_SIZE_GB = 5

linear_class_predicate = (
    lambda m: isinstance(m, nn.Linear)
    and m.weight.shape[0]
    != 8  # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
)
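
# Illustrative note (not part of the original module): the `!= 8` check skips
# Linear layers whose output dimension is 8, which in the Mixtral MoE models is
# the router/gate projecting to the number of experts, e.g.
#
#     linear_class_predicate(nn.Linear(4096, 8))     # -> False, gate left unquantized
#     linear_class_predicate(nn.Linear(4096, 4096))  # -> True, layer gets quantized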


def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
    model_type = config["model_type"]
    if model_type not in MODEL_MAPPING:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        raise ValueError(msg)

    arch = MODEL_MAPPING[model_type]
    return arch.Model, arch.ModelArgs
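
# Illustrative sketch: a config with "model_type": "mistral" resolves to the
# llama implementation via MODEL_MAPPING, e.g.
#
#     model_class, model_args_class = _get_classes({"model_type": "mistral"})
#     # model_class is llama.Model, model_args_class is llama.ModelArgs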


def get_model_path(path_or_hf_repo: str) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.

    Returns:
        Path: The path to the model.
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        model_path = Path(
            snapshot_download(
                repo_id=path_or_hf_repo,
                allow_patterns=[
                    "*.json",
                    "*.safetensors",
                    "*.py",
                    "tokenizer.model",
                    "*.tiktoken",
                ],
            )
        )
    return model_path
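
# Usage sketch (the repo id below is a placeholder, not something this module ships):
#
#     local = get_model_path("./my_converted_model")       # existing directory, returned as-is
#     cached = get_model_path("mlx-community/some-model")  # downloaded into the Hugging Face cache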


def generate_step(
    prompt: mx.array,
    model: nn.Module,
    temp: float = 0.0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        temp (float): The temperature for sampling, if 0 the argmax is used.

    Yields:
        Tuple[mx.array, mx.array]: One token and its probability per call.
    """

    def sample(logits: mx.array) -> Tuple[mx.array, float]:
        softmax_logits = mx.softmax(logits)

        if temp == 0:
            token = mx.argmax(logits, axis=-1)
        else:
            token = mx.random.categorical(logits * (1 / temp))

        prob = softmax_logits[0, token]
        return token, prob

    y = prompt
    cache = None
    while True:
        logits, cache = model(y[None], cache=cache)
        logits = logits[:, -1, :]
        y, prob = sample(logits)
        yield y, prob
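
# Consumption sketch (illustrative; mirrors what `generate` below does). The
# generator never terminates on its own, so cap it and check for EOS yourself:
#
#     prompt_ids = mx.array(tokenizer.encode("hello"))
#     for (token, prob), _ in zip(generate_step(prompt_ids, model), range(20)):
#         if token == tokenizer.eos_token_id:
#             break
#         print(tokenizer.decode([token.item()]), end="", flush=True)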


def generate(
    model: nn.Module,
    tokenizer: PreTrainedTokenizer,
    prompt: str,
    temp: float = 0.0,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Callable = None,
) -> str:
    """
    Generate text from the model.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        temp (float): The temperature for sampling (default 0).
        max_tokens (int): The maximum number of tokens (default 100).
        verbose (bool): If ``True``, print tokens and timing information
            (default ``False``).
        formatter (Optional[Callable]): A function which takes a token and a
            probability and displays it.

    Returns:
        str: The generated text, with any Unicode replacement characters removed.
    """
    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)

    prompt = mx.array(tokenizer.encode(prompt))

    tic = time.perf_counter()
    tokens = []
    skip = 0
    REPLACEMENT_CHAR = "\ufffd"

    for (token, prob), n in zip(generate_step(prompt, model, temp), range(max_tokens)):
        if token == tokenizer.eos_token_id:
            break
        if n == 0:
            prompt_time = time.perf_counter() - tic
            tic = time.perf_counter()
        tokens.append(token.item())

        if verbose:
            s = tokenizer.decode(tokens)
            if formatter:
                formatter(s[skip:], prob.item())
                skip = len(s)
            elif REPLACEMENT_CHAR not in s:
                print(s[skip:], end="", flush=True)
                skip = len(s)

    tokens = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")

    if verbose:
        print(tokens[skip:], flush=True)
        gen_time = time.perf_counter() - tic
        print("=" * 10)
        if len(tokens) == 0:
            print("No tokens generated for this prompt")
            return
        prompt_tps = prompt.size / prompt_time
        gen_tps = (len(tokens) - 1) / gen_time
        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
        print(f"Generation: {gen_tps:.3f} tokens-per-sec")

    return tokens
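
# End-to-end sketch (same pattern as the model card generated by upload_to_hub
# below; the repo id is a placeholder):
#
#     model, tokenizer = load("mlx-community/some-4bit-model")
#     response = generate(model, tokenizer, prompt="hello", max_tokens=100, verbose=True)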


def load_model(model_path: Path) -> nn.Module:
    """
    Load and initialize the model from a given path.

    Args:
        model_path (Path): The path to load the model from.

    Returns:
        nn.Module: The loaded and initialized model.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If the model class or args class are not found or cannot be instantiated.
    """
    try:
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)
        quantization = config.get("quantization", None)
    except FileNotFoundError:
        logging.error(f"Config file not found in {model_path}")
        raise

    weight_files = glob.glob(str(model_path / "*.safetensors"))
    if not weight_files:
        logging.error(f"No safetensors found in {model_path}")
        raise FileNotFoundError(f"No safetensors found in {model_path}")

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))

    model_class, model_args_class = _get_classes(config=config)
    if hasattr(model_class, "sanitize"):
        weights = model_class.sanitize(weights)

    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)

    if quantization is not None:
        # for legacy models that don't have lm_head quant due to non-32 dims
        if "lm_head.scales" not in weights.keys():
            vocab_size = config["vocab_size"]
            extended_linear_class_predicate = (
                lambda layer: linear_class_predicate(layer)
                and layer.weight.shape[0] != vocab_size
            )
            nn.QuantizedLinear.quantize_module(
                model,
                **quantization,
                linear_class_predicate=extended_linear_class_predicate,
            )
        # for models that have lm_head quant
        else:
            nn.QuantizedLinear.quantize_module(
                model,
                **quantization,
                linear_class_predicate=linear_class_predicate,
            )

    model.load_weights(list(weights.items()))

    mx.eval(model.parameters())

    model.eval()
    return model
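
# Notes (illustrative): `load_model` expects `model_path` to be a directory
# containing config.json plus one or more *.safetensors shards. For quantized
# conversions the config typically carries an entry along the lines of
# {"quantization": {"group_size": 64, "bits": 4}}, which is splatted into
# nn.QuantizedLinear.quantize_module above.
#
#     model = load_model(get_model_path("mlx-community/some-model"))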


def load(
    path_or_hf_repo: str, tokenizer_config={}, adapter_file: str = None
) -> Tuple[nn.Module, PreTrainedTokenizer]:
    """
    Load the model and tokenizer from a given path or a huggingface repository.

    Args:
        path_or_hf_repo (str): The local path or huggingface repository to load the model from.
        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
            Defaults to an empty dictionary.
        adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers
            to the model. Defaults to None.

    Returns:
        Tuple[nn.Module, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """
    model_path = get_model_path(path_or_hf_repo)

    model = load_model(model_path)
    if adapter_file is not None:
        model = apply_lora_layers(model, adapter_file)
        model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
    return model, tokenizer
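
# Usage sketch (repo id and file names below are placeholders): tokenizer
# kwargs are passed straight through to AutoTokenizer.from_pretrained, and
# `adapter_file` points at LoRA weights to apply at load time.
#
#     model, tokenizer = load(
#         "mlx-community/some-model",
#         tokenizer_config={"trust_remote_code": True},
#         adapter_file="adapters.npz",
#     )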


def fetch_from_hub(
    model_path: Path,
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
    model = load_model(model_path)

    config = AutoConfig.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    return model, config.to_dict(), tokenizer


def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
    """
    Splits the weights into smaller shards.

    Args:
        weights (dict): Model weights.
        max_file_size_gb (int): Maximum size of each shard in gigabytes.

    Returns:
        list: List of weight shards.
    """
    max_file_size_bytes = max_file_size_gb << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        if shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards
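
# Worked example (illustrative): with the default MAX_FILE_SIZE_GB = 5 the byte
# limit is 5 << 30 = 5,368,709,120 bytes, so roughly 14 GB of fp16 weights ends
# up in about three shards of at most ~5 GiB each (packed greedily in insertion
# order, so the exact count depends on individual tensor sizes).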


def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Uploads the model to Hugging Face hub.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
        hf_path (str): Path to the original Hugging Face model.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    card = ModelCard.load(hf_path)
    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
    card.text = f"""
# {upload_repo}
This model was converted to MLX format from [`{hf_path}`]().
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
## Use with mlx

```bash
pip install mlx-lm
```

```python
from mlx_lm import load, generate

model, tokenizer = load("{upload_repo}")
response = generate(model, tokenizer, prompt="hello", verbose=True)
```
"""
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi()
    api.create_repo(repo_id=upload_repo, exist_ok=True)
    api.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
    )
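
# Usage sketch (repo names are placeholders): after converting a model into
# `path`, push it to the Hub under `upload_repo`, linking back to `hf_path` in
# the generated README.
#
#     upload_to_hub("mlx_model", "mlx-community/some-model-4bit", "org/original-model")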


def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
    """Save model weights into specified directory."""
    if isinstance(save_path, str):
        save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    shards = make_shards(weights)
    shards_count = len(shards)
    shard_file_format = (
        "model-{:05d}-of-{:05d}.safetensors"
        if shards_count > 1
        else "model.safetensors"
    )

    for i, shard in enumerate(shards):
        shard_name = shard_file_format.format(i + 1, shards_count)
        mx.save_safetensors(str(save_path / shard_name), shard)
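
# Naming sketch: a single shard is written as model.safetensors, while three
# shards come out as model-00001-of-00003.safetensors through
# model-00003-of-00003.safetensors inside `save_path`.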