# Copyright © 2023-2024 Apple Inc.

import contextlib
import copy
import glob
import importlib
import json
import logging
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer

# Local imports
from .models import cache
from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters, nparams

# Constants
MODEL_REMAPPING = {
    "mistral": "llama",  # mistral is compatible with llama
    "phi-msft": "phixtral",
    "falcon_mamba": "mamba",
}

MAX_FILE_SIZE_GB = 5

# A stream on the default device just for generation
generation_stream = mx.new_stream(mx.default_device())


class ModelNotFoundError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)


@dataclass
class GenerationResponse:
    """
    The output of :func:`stream_generate`.

    Args:
        text (str): The next segment of decoded text. This can be an empty string.
        token (int): The next token.
        logprobs (mx.array): A vector of log probabilities.
        prompt_tokens (int): The number of tokens in the prompt.
        prompt_tps (float): The prompt processing tokens-per-second.
        generation_tokens (int): The number of generated tokens.
        generation_tps (float): The tokens-per-second for generation.
        peak_memory (float): The peak memory used so far in GB.
        finish_reason (str): The reason the response is being sent: "length", "stop", or ``None``.
    """

    text: str
    token: int
    logprobs: mx.array
    prompt_tokens: int
    prompt_tps: float
    generation_tokens: int
    generation_tps: float
    peak_memory: float
    finish_reason: Optional[str] = None


@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
    """
    A context manager to temporarily change the wired limit.

    Note, the wired limit should not be changed during an async eval. If an
    async eval could be running, pass in the streams to synchronize with prior
    to exiting the context manager.
    """
    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
    if model_bytes > 0.9 * max_rec_size:
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
        print(
            f"[WARNING] Generating with a model that requires {model_mb} MB "
            f"which is close to the maximum recommended size of {max_rec_mb} "
            "MB. This can be slow. See the documentation for possible work-arounds: "
            "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
        )
    old_limit = mx.metal.set_wired_limit(max_rec_size)
    try:
        yield None
    finally:
        if streams is not None:
            for s in streams:
                mx.synchronize(s)
        else:
            mx.synchronize()
        mx.metal.set_wired_limit(old_limit)
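

# Example (illustrative sketch): `stream_generate` below already wraps generation
# in this context manager. Manual use looks like the following, with `model`
# obtained from `load` further down in this module.
#
#     with wired_limit(model, [generation_stream]):
#         for token, logprobs in generate_step(prompt_tokens, model, max_tokens=32):
#             ...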


def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
    model_type = config["model_type"]
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    try:
        arch = importlib.import_module(f"mlx_lm.models.{model_type}")
    except ImportError:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        raise ValueError(msg)

    return arch.Model, arch.ModelArgs


def compute_bits_per_weight(model):
    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    leaf_modules = tree_flatten(
        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
    )
    model_params = sum(nparams(m) for _, m in leaf_modules)
    return model_bytes * 8 / model_params
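

# Rough arithmetic for intuition (assuming MLX affine quantization, which stores
# one fp16 scale and one fp16 bias per group): a 4-bit layer with group size 64
# packs 64 * 4 + 16 + 16 = 288 bits per group, i.e. 288 / 64 = 4.5 bits per
# weight. Unquantized parameters such as norms push the model-wide value
# returned by `compute_bits_per_weight` somewhat higher.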


def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

    Returns:
        Path: The path to the model.
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        try:
            model_path = Path(
                snapshot_download(
                    repo_id=path_or_hf_repo,
                    revision=revision,
                    allow_patterns=[
                        "*.json",
                        "*.safetensors",
                        "*.py",
                        "tokenizer.model",
                        "*.tiktoken",
                        "*.txt",
                    ],
                )
            )
        except:
            raise ModelNotFoundError(
                f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
                "Please make sure you specified the local path or Hugging Face"
                " repo id correctly.\nIf you are trying to access a private or"
                " gated Hugging Face repo, make sure you are authenticated:\n"
                "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
            ) from None
    return model_path
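

# Example (illustrative; the repo id is a placeholder):
#
#     model_path = get_model_path("mlx-community/some-model-4bit")
#     config = load_config(model_path)  # `load_config` is defined later in this module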


def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
    if (
        kv_bits is not None
        and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
        and prompt_cache[0].offset > quantized_kv_start
    ):
        for i in range(len(prompt_cache)):
            if isinstance(prompt_cache[i], cache.KVCache):
                prompt_cache[i] = prompt_cache[i].to_quantized(
                    group_size=kv_group_size, bits=kv_bits
                )


def generate_step(
    prompt: mx.array,
    model: nn.Module,
    *,
    max_tokens: int = 256,
    sampler: Optional[Callable[[mx.array], mx.array]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
    prefill_step_size: int = 512,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
    temp: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    top_p: Optional[float] = None,
    min_p: Optional[float] = None,
    min_tokens_to_keep: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
            generator. Default: ``256``.
        sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a
            token from a vector of log probabilities. Default: ``None``.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
            A list of functions that take tokens and logits and return the processed
            logits. Default: ``None``.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
            entries (except the first 4 tokens) will be overwritten.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
            provided, the cache will be updated in place.
        prefill_step_size (int): Step size for processing the prompt.
        kv_bits (int, optional): Number of bits to use for KV cache quantization.
            ``None`` implies no cache quantization. Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
            when ``kv_bits`` is non-None. Default: ``0``.
        prompt_progress_callback (Callable[[int, int], None], optional): A callback
            which takes the number of prompt tokens processed so far and the total
            number of prompt tokens.

    Yields:
        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
    """

    y = prompt
    tokens = None

    # Create the KV cache for generation
    if prompt_cache is None:
        prompt_cache = cache.make_prompt_cache(
            model,
            max_kv_size=max_kv_size,
        )
    elif len(prompt_cache) != len(model.layers):
        raise ValueError("Wrong number of layers in the prompt cache.")

    if temp is not None or top_p is not None or min_tokens_to_keep is not None:
        print(
            "[Warning] Specifying sampling arguments to ``generate_step`` is "
            "deprecated. Pass in a ``sampler`` instead."
        )
    if repetition_penalty is not None:
        print(
            "[Warning] Specifying ``repetition_penalty`` is deprecated. "
            "Pass in ``logits_processors`` instead."
        )

    sampler = sampler or make_sampler(
        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
    )
    logits_processors = logits_processors or make_logits_processors(
        None, repetition_penalty, repetition_context_size or 20
    )
    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)

    def _step(y):
        with mx.stream(generation_stream):
            logits = model(y[None], cache=prompt_cache)
            logits = logits[:, -1, :]

            if logits_processors:
                nonlocal tokens
                tokens = mx.concat([tokens, y]) if tokens is not None else y

                for processor in logits_processors:
                    logits = processor(tokens, logits)

            maybe_quantize_kv_cache(
                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
            )

            logprobs = logits - mx.logsumexp(logits, keepdims=True)
            y = sampler(logprobs)
            return y, logprobs.squeeze(0)

    with mx.stream(generation_stream):
        total_prompt_tokens = y.size
        prompt_processed_tokens = 0
        while y.size > prefill_step_size:
            model(y[:prefill_step_size][None], cache=prompt_cache)
            maybe_quantize_kv_cache(
                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
            )
            mx.eval([c.state for c in prompt_cache])
            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
            prompt_processed_tokens += prefill_step_size
            y = y[prefill_step_size:]
            mx.metal.clear_cache()

        y, logprobs = _step(y)

    mx.async_eval(y, logprobs)
    n = 0
    while True:
        if n != max_tokens:
            next_y, next_logprobs = _step(y)
            mx.async_eval(next_y, next_logprobs)
        if n == 0:
            mx.eval(y)
            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
        if n == max_tokens:
            break
        yield y.item(), logprobs
        if n % 256 == 0:
            mx.metal.clear_cache()
        y, logprobs = next_y, next_logprobs
        n += 1
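

# Example (illustrative sketch): drive the token-level generator directly with a
# custom sampler. `load` is defined later in this module, the model path is a
# placeholder, and `make_sampler` takes (temp, top_p, ...) positionally as used above.
#
#     model, tokenizer = load("path/to/mlx/model")
#     prompt = mx.array(tokenizer.encode("Hello"))
#     sampler = make_sampler(0.8, 0.95)  # temperature, top-p
#     for token, logprobs in generate_step(prompt, model, max_tokens=32, sampler=sampler):
#         print(tokenizer.decode([token]), end="", flush=True)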


def stream_generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, mx.array, List[int]],
    **kwargs,
) -> Generator[GenerationResponse, None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        model (nn.Module): The model to use for generation.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.

    Yields:
        GenerationResponse: An instance containing the generated text segment and
            associated metadata. See :class:`GenerationResponse` for details.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    if not isinstance(prompt, mx.array):
        prompt = mx.array(
            prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
        )

    detokenizer = tokenizer.detokenizer

    with wired_limit(model, [generation_stream]):
        detokenizer.reset()
        tic = time.perf_counter()
        for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
            if n == 0:
                prompt_time = time.perf_counter() - tic
                prompt_tps = prompt.size / prompt_time
                tic = time.perf_counter()
            if token in tokenizer.eos_token_ids:
                break

            detokenizer.add_token(token)

            yield GenerationResponse(
                text=detokenizer.last_segment,
                token=token,
                logprobs=logprobs,
                prompt_tokens=prompt.size,
                prompt_tps=prompt_tps,
                generation_tokens=n + 1,
                generation_tps=(n + 1) / (time.perf_counter() - tic),
                peak_memory=mx.metal.get_peak_memory() / 1e9,
                finish_reason=None,
            )

        detokenizer.finalize()
        yield GenerationResponse(
            text=detokenizer.last_segment,
            token=token,
            logprobs=logprobs,
            prompt_tokens=prompt.size,
            prompt_tps=prompt_tps,
            generation_tokens=n + 1,
            generation_tps=(n + 1) / (time.perf_counter() - tic),
            peak_memory=mx.metal.get_peak_memory() / 1e9,
            finish_reason="stop" if token in tokenizer.eos_token_ids else "length",
        )
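

# Example (illustrative sketch; the model path is a placeholder):
#
#     model, tokenizer = load("path/to/mlx/model")
#     for response in stream_generate(model, tokenizer, "Write a haiku", max_tokens=64):
#         print(response.text, end="", flush=True)
#     print(f"\n{response.generation_tps:.1f} tokens-per-sec")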


def generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    **kwargs,
) -> str:
    """
    Generate a complete response from the model.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
        kwargs: The remaining options get passed to :func:`stream_generate`.
            See :func:`stream_generate` for more details.
    """
    if formatter is not None:
        print(
            "[Warning] Text formatting is deprecated and no longer used. "
            "The argument will be removed in a future version."
        )
    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)

    text = ""
    for response in stream_generate(model, tokenizer, prompt, **kwargs):
        if verbose:
            print(response.text, end="", flush=True)
        text += response.text

    if verbose:
        print()
        print("=" * 10)
        if len(text) == 0:
            print("No text generated for this prompt")
            return
        print(
            f"Prompt: {response.prompt_tokens} tokens, "
            f"{response.prompt_tps:.3f} tokens-per-sec"
        )
        print(
            f"Generation: {response.generation_tokens} tokens, "
            f"{response.generation_tps:.3f} tokens-per-sec"
        )
        print(f"Peak memory: {response.peak_memory:.3f} GB")
    return text
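

# Example (illustrative sketch; mirrors the model-card snippet emitted by
# `upload_to_hub` below, with a placeholder model path):
#
#     model, tokenizer = load("path/to/mlx/model")
#     messages = [{"role": "user", "content": "Hello"}]
#     prompt = tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )
#     text = generate(model, tokenizer, prompt=prompt, max_tokens=128, verbose=True)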


def load_config(model_path: Path) -> dict:
    try:
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error(f"Config file not found in {model_path}")
        raise
    return config


def load_model(
    model_path: Path,
    lazy: bool = False,
    model_config: dict = {},
    get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> Tuple[nn.Module, dict]:
    """
    Load and initialize the model from a given path.

    Args:
        model_path (Path): The path to load the model from.
        lazy (bool): If ``False``, eval the model parameters to make sure they
            are loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        model_config (dict, optional): Optional configuration parameters for the
            model. Defaults to an empty dictionary.
        get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
            A function that returns the model class and model args class given a config.
            Defaults to the ``_get_classes`` function.

    Returns:
        Tuple[nn.Module, dict]: The loaded and initialized model and its configuration.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If the model class or args class are not found or cannot be instantiated.
    """
    config = load_config(model_path)
    config.update(model_config)

    weight_files = glob.glob(str(model_path / "model*.safetensors"))

    if not weight_files:
        # Try weight for back-compat
        weight_files = glob.glob(str(model_path / "weight*.safetensors"))

    if not weight_files:
        logging.error(f"No safetensors found in {model_path}")
        raise FileNotFoundError(f"No safetensors found in {model_path}")

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))

    model_class, model_args_class = get_model_classes(config=config)

    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)

    if hasattr(model, "sanitize"):
        weights = model.sanitize(weights)

    if (quantization := config.get("quantization", None)) is not None:

        def class_predicate(p, m):
            # Handle custom per layer quantizations
            if p in config["quantization"]:
                return config["quantization"][p]
            if not hasattr(m, "to_quantized"):
                return False
            # Handle legacy models which may not have everything quantized
            return f"{p}.scales" in weights

        nn.quantize(
            model,
            group_size=quantization["group_size"],
            bits=quantization["bits"],
            quantization_type=quantization["quantization_type"],
            class_predicate=class_predicate,
        )

    model.load_weights(list(weights.items()))

    if not lazy:
        mx.eval(model.parameters())

    model.eval()
    return model, config


def load(
    path_or_hf_repo: str,
    tokenizer_config={},
    model_config={},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
    """
    Load the model and tokenizer from a given path or a Hugging Face repository.

    Args:
        path_or_hf_repo (Path): The path or the Hugging Face repository to load the model from.
        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
            Defaults to an empty dictionary.
        model_config (dict, optional): Configuration parameters specifically for the model.
            Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
            to the model. Default: ``None``.
        lazy (bool): If ``False``, eval the model parameters to make sure they
            are loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``

    Returns:
        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """
    model_path = get_model_path(path_or_hf_repo)

    model, config = load_model(model_path, lazy)
    if adapter_path is not None:
        model = load_adapters(model, adapter_path)
        model.eval()
    tokenizer = load_tokenizer(
        model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None)
    )

    return model, tokenizer
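

# Example (illustrative; paths are placeholders and `trust_remote_code` is only
# an assumption for tokenizers that ship custom code):
#
#     model, tokenizer = load(
#         "path/to/mlx/model",                     # local path or Hugging Face repo id
#         adapter_path="path/to/adapters",         # optional LoRA adapters
#         tokenizer_config={"trust_remote_code": True},
#     )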


def fetch_from_hub(
    model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
    model, config = load_model(model_path, lazy)
    tokenizer = load_tokenizer(
        model_path, eos_token_ids=config.get("eos_token_id", None)
    )
    return model, config, tokenizer


def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
    """
    Splits the weights into smaller shards.

    Args:
        weights (dict): Model weights.
        max_file_size_gb (int): Maximum size of each shard in gigabytes.

    Returns:
        list: List of weight shards.
    """
    max_file_size_bytes = max_file_size_gb << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        if shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards
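

# Sharding arithmetic for intuition (illustrative): with the default
# MAX_FILE_SIZE_GB = 5, roughly 13 GB of tensors are split into three shards of
# about 5 GB, 5 GB, and 3 GB, which `save_weights` below names
# model-00001-of-00003.safetensors, and so on.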


def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Uploads the model to Hugging Face hub.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
        hf_path (str): Path to the original Hugging Face model.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    from . import __version__

    card = ModelCard.load(hf_path)
    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
    card.data.base_model = hf_path
    card.text = dedent(
        f"""
        # {upload_repo}

        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
        converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
        using mlx-lm version **{__version__}**.

        ## Use with mlx

        ```bash
        pip install mlx-lm
        ```

        ```python
        from mlx_lm import load, generate

        model, tokenizer = load("{upload_repo}")

        prompt = "hello"

        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
            messages = [{{"role": "user", "content": prompt}}]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        response = generate(model, tokenizer, prompt=prompt, verbose=True)
        ```
        """
    )
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi()
    api.create_repo(repo_id=upload_repo, exist_ok=True)
    api.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
        multi_commits=True,
        multi_commits_verbose=True,
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")


def save_weights(
    save_path: Union[str, Path],
    weights: Dict[str, Any],
    *,
    donate_weights: bool = False,
) -> None:
    """Save model weights into specified directory."""
    if isinstance(save_path, str):
        save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    shards = make_shards(weights)
    shards_count = len(shards)
    shard_file_format = (
        "model-{:05d}-of-{:05d}.safetensors"
        if shards_count > 1
        else "model.safetensors"
    )

    total_size = sum(v.nbytes for v in weights.values())
    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}

    # Write the weights and make sure no references are kept other than the
    # necessary ones
    if donate_weights:
        weights.clear()
        del weights

    for i in range(len(shards)):
        shard = shards[i]
        shards[i] = None
        shard_name = shard_file_format.format(i + 1, shards_count)
        shard_path = save_path / shard_name

        mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"})

        for weight_name in shard.keys():
            index_data["weight_map"][weight_name] = shard_name
        del shard

    index_data["weight_map"] = {
        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
    }

    with open(save_path / "model.safetensors.index.json", "w") as f:
        json.dump(
            index_data,
            f,
            indent=4,
        )
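

# Example (illustrative sketch):
#
#     weights = dict(tree_flatten(model.parameters()))
#     save_weights("mlx_model", weights, donate_weights=True)
#     # writes mlx_model/model*.safetensors and model.safetensors.index.json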


def quantize_model(
    model: nn.Module,
    config: dict,
    q_group_size: int,
    q_bits: int,
    q_type: str,
    quant_predicate: Optional[
        Callable[[str, nn.Module, dict], Union[bool, dict]]
    ] = None,
) -> Tuple:
    """
    Applies quantization to the model weights.

    Args:
        model (nn.Module): The model to be quantized.
        config (dict): Model configuration.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.
        q_type (str): Quantization type.
        quant_predicate (Callable): A callable that decides how
            to quantize each layer based on the path.
            Accepts the layer `path`, the `module` and the model `config`.
            Returns either a bool to signify quantize/no quantize or
            a dict of quantization parameters to pass to `to_quantized`.

    Returns:
        Tuple: Tuple containing quantized weights and config.
    """
    quantized_config = copy.deepcopy(config)
    quantized_config["quantization"] = {
        "group_size": q_group_size,
        "bits": q_bits,
        "quantization_type": q_type,
    }

    # Add any custom quantization parameters to the config as we go
    def _class_predicate(p, m):
        if quant_predicate:
            bool_or_params = quant_predicate(p, m, config)
        else:
            if isinstance(m, nn.Embedding):
                bool_or_params = {
                    "group_size": q_group_size,
                    "bits": q_bits,
                    "quantization_type": "affine",
                }
            else:
                bool_or_params = hasattr(m, "to_quantized")
        quantized_config["quantization"][p] = bool_or_params
        return bool_or_params

    nn.quantize(
        model,
        q_group_size,
        q_bits,
        quantization_type=q_type,
        class_predicate=_class_predicate,
    )
    # support hf model tree #957
    quantized_config["quantization_config"] = quantized_config["quantization"]
    quantized_weights = dict(tree_flatten(model.parameters()))

    bpw = compute_bits_per_weight(model)
    print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")

    return quantized_weights, quantized_config
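

# Example (illustrative sketch): quantize to 4 bits but keep the output head at
# 8 bits via a custom predicate. The "lm_head" path is a guess at a common layer
# name and may differ per architecture.
#
#     def keep_head_at_8_bits(path, module, config):
#         if "lm_head" in path:
#             return {"group_size": 64, "bits": 8}
#         return hasattr(module, "to_quantized")
#
#     weights, config = quantize_model(
#         model, config, q_group_size=64, q_bits=4, q_type="affine",
#         quant_predicate=keep_head_at_8_bits,
#     )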


def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Save the model configuration to the ``config_path``.

    The final configuration will be sorted before saving for better readability.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """
    # Clean unused keys
    config.pop("_name_or_path", None)

    # sort the config for better readability
    config = dict(sorted(config.items()))

    # write the updated config to the config_path (if provided)
    with open(config_path, "w") as fid:
        json.dump(config, fid, indent=4)


def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    q_type: str = "affine",
    dtype: str = "float16",
    upload_repo: Optional[str] = None,
    revision: Optional[str] = None,
    dequantize: bool = False,
    quant_predicate: Optional[
        Callable[[str, nn.Module, dict], Union[bool, dict]]
    ] = None,
):
    # Check the save path is empty
    if isinstance(mlx_path, str):
        mlx_path = Path(mlx_path)

    if mlx_path.exists():
        raise ValueError(
            f"Cannot save to the path {mlx_path} as it already exists."
            " Please delete the file/directory or specify a new path to save to."
        )

    print("[INFO] Loading")
    model_path = get_model_path(hf_path, revision=revision)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

    weights = dict(tree_flatten(model.parameters()))
    dtype = getattr(mx, dtype)
    weights = {k: v.astype(dtype) for k, v in weights.items()}

    if quantize and dequantize:
        raise ValueError("Choose either quantize or dequantize, not both.")

    if quantize:
        print("[INFO] Quantizing")
        model.load_weights(list(weights.items()))
        weights, config = quantize_model(
            model,
            config,
            q_group_size,
            q_bits,
            q_type=q_type,
            quant_predicate=quant_predicate,
        )

    if dequantize:
        print("[INFO] Dequantizing")
        model = dequantize_model(model)
        weights = dict(tree_flatten(model.parameters()))

    del model
    save_weights(mlx_path, weights, donate_weights=True)

    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)

    save_config(config, config_path=mlx_path / "config.json")

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, hf_path)
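

# Example (illustrative; the Hugging Face path is a placeholder):
#
#     convert(
#         "some-org/some-model",
#         mlx_path="mlx_model",
#         quantize=True,
#         q_bits=4,
#         q_group_size=64,
#     )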