# Copyright © 2023-2024 Apple Inc.

import contextlib
import copy
import functools
import glob
import importlib
import json
import logging
import os
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union

import mlx.core as mx
import mlx.nn as nn

if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true":
    try:
        from modelscope import snapshot_download
    except ImportError:
        raise ImportError(
            "Please run `pip install modelscope` to activate ModelScope."
        )
else:
    from huggingface_hub import snapshot_download

from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer

# Local imports
from .models import cache
from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters, nparams

# Constants
MODEL_REMAPPING = {
    "mistral": "llama",  # mistral is compatible with llama
    "phi-msft": "phixtral",
    "falcon_mamba": "mamba",
}

MAX_FILE_SIZE_GB = 5

# A stream on the default device just for generation
generation_stream = mx.new_stream(mx.default_device())


class ModelNotFoundError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)


@dataclass
class GenerationResponse:
    """
    The output of :func:`stream_generate`.

    Args:
        text (str): The next segment of decoded text. This can be an empty string.
        token (int): The next token.
        logprobs (mx.array): A vector of log probabilities.
        prompt_tokens (int): The number of tokens in the prompt.
        prompt_tps (float): The prompt processing tokens-per-second.
        generation_tokens (int): The number of generated tokens.
        generation_tps (float): The tokens-per-second for generation.
        peak_memory (float): The peak memory used so far in GB.
        finish_reason (str): The reason the response is being sent: "length",
            "stop", or ``None``.
    """

    text: str
    token: int
    logprobs: mx.array
    prompt_tokens: int
    prompt_tps: float
    generation_tokens: int
    generation_tps: float
    peak_memory: float
    finish_reason: Optional[str] = None


@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
    """
    A context manager to temporarily change the wired limit.

    Note, the wired limit should not be changed during an async eval. If an
    async eval could be running, pass in the streams to synchronize with
    prior to exiting the context manager.
    """
    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
    if model_bytes > 0.9 * max_rec_size:
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
        print(
            f"[WARNING] Generating with a model that requires {model_mb} MB "
            f"which is close to the maximum recommended size of {max_rec_mb} "
            "MB. This can be slow. See the documentation for possible work-arounds: "
            "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
        )
    old_limit = mx.metal.set_wired_limit(max_rec_size)
    try:
        yield None
    finally:
        if streams is not None:
            for s in streams:
                mx.synchronize(s)
        else:
            mx.synchronize()
        mx.metal.set_wired_limit(old_limit)


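# Example (illustrative sketch): `wired_limit` is a context manager, so callers
# can wrap generation with it roughly as follows; this mirrors how
# `stream_generate` uses it further below.
#
#     with wired_limit(model, [generation_stream]):
#         for token, logprobs in generate_step(prompt, model):
#             ...

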
def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
    model_type = config["model_type"]
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    try:
        arch = importlib.import_module(f"mlx_lm.models.{model_type}")
    except ImportError:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        raise ValueError(msg)

    return arch.Model, arch.ModelArgs


def compute_bits_per_weight(model):
    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    leaf_modules = tree_flatten(
        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
    )
    model_params = sum(nparams(m) for _, m in leaf_modules)
    return model_bytes * 8 / model_params


def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

    Returns:
        Path: The path to the model.
    """
    model_path = Path(path_or_hf_repo)

    if not model_path.exists():
        try:
            model_path = Path(
                snapshot_download(
                    path_or_hf_repo,
                    revision=revision,
                    allow_patterns=[
                        "*.json",
                        "*.safetensors",
                        "*.py",
                        "tokenizer.model",
                        "*.tiktoken",
                        "*.txt",
                    ],
                )
            )
        except:
            raise ModelNotFoundError(
                f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
                "Please make sure you specified the local path or Hugging Face"
                " repo id correctly.\nIf you are trying to access a private or"
                " gated Hugging Face repo, make sure you are authenticated:\n"
                "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
            ) from None
    return model_path


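# Example (illustrative sketch; the repo id is a placeholder): resolve a Hub
# repo to a local snapshot directory and read its config.
#
#     model_path = get_model_path("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
#     config = load_config(model_path)

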
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
    if (
        kv_bits is not None
        and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
        and prompt_cache[0].offset > quantized_kv_start
    ):
        for i in range(len(prompt_cache)):
            if isinstance(prompt_cache[i], cache.KVCache):
                prompt_cache[i] = prompt_cache[i].to_quantized(
                    group_size=kv_group_size, bits=kv_bits
                )


def generate_step(
    prompt: mx.array,
    model: nn.Module,
    *,
    max_tokens: int = 256,
    sampler: Optional[Callable[[mx.array], mx.array]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
    prefill_step_size: int = 512,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
            generator. Default: ``256``.
        sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a
            token from a vector of log probabilities. Default: ``None``.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
            A list of functions that take tokens and logits and return the processed
            logits. Default: ``None``.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
            entries (except the first 4 tokens) will be overwritten.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
            provided, the cache will be updated in place.
        prefill_step_size (int): Step size for processing the prompt.
        kv_bits (int, optional): Number of bits to use for KV cache quantization.
            None implies no cache quantization. Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
            when ``kv_bits`` is non-None. Default: ``0``.
        prompt_progress_callback (Callable[[int, int], None]): A callback which takes
            the prompt tokens processed so far and the total number of prompt tokens.

    Yields:
        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
    """

    y = prompt
    tokens = None

    # Create the KV cache for generation
    if prompt_cache is None:
        prompt_cache = cache.make_prompt_cache(
            model,
            max_kv_size=max_kv_size,
        )
    elif len(prompt_cache) != len(model.layers):
        raise ValueError("Wrong number of layers in the prompt cache.")

    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)

    quantize_cache_fn = functools.partial(
        maybe_quantize_kv_cache,
        quantized_kv_start=quantized_kv_start,
        kv_group_size=kv_group_size,
        kv_bits=kv_bits,
    )

    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))

    def _step(y):
        with mx.stream(generation_stream):
            logits = model(y[None], cache=prompt_cache)
            logits = logits[:, -1, :]

            if logits_processors:
                nonlocal tokens
                tokens = mx.concat([tokens, y]) if tokens is not None else y

                for processor in logits_processors:
                    logits = processor(tokens, logits)

            quantize_cache_fn(prompt_cache)

            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
            y = sampler(logprobs)
            return y, logprobs.squeeze(0)

    with mx.stream(generation_stream):
        total_prompt_tokens = y.size
        prompt_processed_tokens = 0
        while y.size > prefill_step_size:
            model(y[:prefill_step_size][None], cache=prompt_cache)
            quantize_cache_fn(prompt_cache)
            mx.eval([c.state for c in prompt_cache])
            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
            prompt_processed_tokens += prefill_step_size
            y = y[prefill_step_size:]
            mx.metal.clear_cache()

        y, logprobs = _step(y)

    mx.eval(y, logprobs)
    n = 0
    while True:
        if n != max_tokens:
            next_y, next_logprobs = _step(y)
            mx.eval(next_y, next_logprobs)
        if n == 0:
            mx.eval(y)
            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
        if n == max_tokens:
            break
        yield y.item(), logprobs
        if n % 256 == 0:
            mx.metal.clear_cache()
        y, logprobs = next_y, next_logprobs
        n += 1


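# Example (illustrative sketch): stepping through tokens directly with
# `generate_step`, assuming `model` and `tokenizer` come from `load` and that
# `make_sampler` takes its temperature as `temp`.
#
#     prompt_ids = mx.array(tokenizer.encode("Hello"))
#     sampler = make_sampler(temp=0.7)
#     for token, logprobs in generate_step(prompt_ids, model, max_tokens=32, sampler=sampler):
#         print(tokenizer.decode([token]), end="", flush=True)

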
def speculative_generate_step(
    prompt: mx.array,
    model: nn.Module,
    draft_model: nn.Module,
    *,
    num_draft_tokens=2,
    max_tokens: int = 256,
    sampler: Optional[Callable[[mx.array], mx.array]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    prompt_cache: Optional[Any] = None,
    prefill_step_size: int = 512,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        draft_model (nn.Module): The draft model for speculative decoding.
        num_draft_tokens (int, optional): The number of draft tokens for
            speculative decoding. Default: ``2``.
        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
            generator. Default: ``256``.
        sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a
            token from a vector of log probabilities. Default: ``None``.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
            A list of functions that take tokens and logits and return the processed
            logits. Default: ``None``.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
            provided, the cache will be updated in place. The cache must be trimmable.
        prefill_step_size (int): Step size for processing the prompt.
        kv_bits (int, optional): Number of bits to use for KV cache quantization.
            None implies no cache quantization. Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
            when ``kv_bits`` is non-None. Default: ``0``.

    Yields:
        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
    """

    y = prompt
    tokens = None

    # Create the KV cache for generation
    if prompt_cache is None:
        model_cache = cache.make_prompt_cache(model)
        draft_cache = cache.make_prompt_cache(draft_model)
    elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)):
        raise ValueError("Wrong number of layers in the prompt cache.")
    else:
        model_cache = prompt_cache[: len(model.layers)]
        draft_cache = prompt_cache[len(model.layers) :]

    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))

    quantize_cache_fn = functools.partial(
        maybe_quantize_kv_cache,
        quantized_kv_start=quantized_kv_start,
        kv_group_size=kv_group_size,
        kv_bits=kv_bits,
    )

    def _step(model, cache, y, n_predict=1):
        with mx.stream(generation_stream):
            logits = model(y[None], cache=cache)
            logits = logits[:, -n_predict:, :]

            quantize_cache_fn(cache)

            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
            logprobs = logprobs.squeeze(0)
            y = sampler(logprobs)
            return y, logprobs

    def _prefill(model, cache, y):
        while y.size > prefill_step_size:
            model(y[:prefill_step_size][None], cache=cache)
            quantize_cache_fn(cache)
            mx.eval([c.state for c in cache])
            y = y[prefill_step_size:]
            mx.metal.clear_cache()
        return y

    def _rewind_cache(num_draft, num_accept):
        cache.trim_prompt_cache(model_cache, num_draft - num_accept)
        cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0))

    def _draft_generate(y, num_draft):
        if num_draft == 0:
            return mx.array([], mx.uint32)
        ys = []
        for _ in range(num_draft):
            y, _ = _step(draft_model, draft_cache, y)
            mx.async_eval(y)
            ys.append(y)
        return mx.concatenate(ys)

    with mx.stream(generation_stream):
        draft_y = _prefill(draft_model, draft_cache, y)
        y = _prefill(model, model_cache, y)

    ntoks = 0
    # Set these so the finally block doesn't raise
    num_draft = 0
    n = 0
    try:
        while True:
            num_draft = min(max_tokens - ntoks, num_draft_tokens)
            draft_tokens = _draft_generate(draft_y, num_draft)
            y = mx.concatenate([y, draft_tokens])

            tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
            mx.eval(tokens, draft_tokens)
            draft_tokens = draft_tokens.tolist()
            tokens = tokens.tolist()
            n = 0
            while n < num_draft:
                tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
                if tn != dtn:
                    break
                n += 1
                ntoks += 1
                yield tn, lpn
                if ntoks == max_tokens:
                    break
            if ntoks < max_tokens:
                ntoks += 1
                yield tokens[n], logprobs[n]

            if ntoks == max_tokens:
                break

            y = mx.array([tokens[n]], mx.uint32)
            draft_y = y

            # If we accepted all the draft tokens, include the last
            # draft token in the next draft step since it hasn't been
            # processed yet by the draft model
            if n == num_draft:
                draft_y = mx.concatenate(
                    [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
                )

            _rewind_cache(num_draft, n)
    finally:
        _rewind_cache(num_draft, n)


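# Example (illustrative sketch; repo ids are placeholders): speculative decoding
# is normally driven through `stream_generate` by passing a draft model that
# shares the main model's tokenizer.
#
#     model, tokenizer = load("mlx-community/Some-8B-Instruct-4bit")
#     draft_model, _ = load("mlx-community/Some-1B-Instruct-4bit")
#     for r in stream_generate(
#         model, tokenizer, "Hello", draft_model=draft_model, num_draft_tokens=3
#     ):
#         print(r.text, end="", flush=True)

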
def stream_generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, mx.array, List[int]],
    draft_model: Optional[nn.Module] = None,
    **kwargs,
) -> Generator[GenerationResponse, None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        model (nn.Module): The model to use for generation.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (Union[str, mx.array, List[int]]): The input prompt string or
            integer tokens.
        draft_model (Optional[nn.Module]): An optional draft model. If provided
            then speculative decoding is used. The draft model must use the same
            tokenizer as the main model. Default: ``None``.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.

    Yields:
        GenerationResponse: An instance containing the generated text segment and
            associated metadata. See :class:`GenerationResponse` for details.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    if not isinstance(prompt, mx.array):
        if isinstance(prompt, str):
            # Try to infer if special tokens are needed
            add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
                tokenizer.bos_token
            )
            prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
        prompt = mx.array(prompt)

    detokenizer = tokenizer.detokenizer

    if draft_model is None:
        kwargs.pop("num_draft_tokens", None)
        token_generator = generate_step(prompt, model, **kwargs)
    else:
        kwargs.pop("max_kv_size", None)
        token_generator = speculative_generate_step(
            prompt, model, draft_model, **kwargs
        )
    with wired_limit(model, [generation_stream]):
        detokenizer.reset()
        tic = time.perf_counter()
        for n, (token, logprobs) in enumerate(token_generator):
            if n == 0:
                prompt_time = time.perf_counter() - tic
                prompt_tps = prompt.size / prompt_time
                tic = time.perf_counter()
            if token in tokenizer.eos_token_ids:
                break

            detokenizer.add_token(token)

            yield GenerationResponse(
                text=detokenizer.last_segment,
                token=token,
                logprobs=logprobs,
                prompt_tokens=prompt.size,
                prompt_tps=prompt_tps,
                generation_tokens=n + 1,
                generation_tps=(n + 1) / (time.perf_counter() - tic),
                peak_memory=mx.metal.get_peak_memory() / 1e9,
                finish_reason=None,
            )

        detokenizer.finalize()
        yield GenerationResponse(
            text=detokenizer.last_segment,
            token=token,
            logprobs=logprobs,
            prompt_tokens=prompt.size,
            prompt_tps=prompt_tps,
            generation_tokens=n + 1,
            generation_tps=(n + 1) / (time.perf_counter() - tic),
            peak_memory=mx.metal.get_peak_memory() / 1e9,
            finish_reason="stop" if token in tokenizer.eos_token_ids else "length",
        )


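# Example (illustrative sketch; the repo id is a placeholder): streaming text
# and reading the metadata carried on each `GenerationResponse`.
#
#     model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
#     for response in stream_generate(model, tokenizer, "Write a haiku.", max_tokens=64):
#         print(response.text, end="", flush=True)
#     print(f"\n{response.generation_tps:.1f} tokens-per-sec")

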
def generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, List[int]],
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    **kwargs,
) -> str:
    """
    Generate a complete response from the model.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
        verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
        kwargs: The remaining options get passed to :func:`stream_generate`.
            See :func:`stream_generate` for more details.

    Returns:
        str: The generated text.
    """
    if formatter is not None:
        print(
            "[Warning] Text formatting is deprecated and no longer used. "
            "The argument will be removed in a future version."
        )
    if verbose:
        print("=" * 10)

    text = ""
    for response in stream_generate(model, tokenizer, prompt, **kwargs):
        if verbose:
            print(response.text, end="", flush=True)
        text += response.text

    if verbose:
        print()
        print("=" * 10)
        if len(text) == 0:
            print("No text generated for this prompt")
            return
        print(
            f"Prompt: {response.prompt_tokens} tokens, "
            f"{response.prompt_tps:.3f} tokens-per-sec"
        )
        print(
            f"Generation: {response.generation_tokens} tokens, "
            f"{response.generation_tps:.3f} tokens-per-sec"
        )
        print(f"Peak memory: {response.peak_memory:.3f} GB")
    return text


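# Example (illustrative sketch; the repo id is a placeholder): this mirrors the
# usage snippet written into uploaded model cards by `upload_to_hub` below.
#
#     model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
#     messages = [{"role": "user", "content": "What is MLX?"}]
#     prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
#     text = generate(model, tokenizer, prompt=prompt, verbose=True)

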
def load_config(model_path: Path) -> dict:
    try:
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error(f"Config file not found in {model_path}")
        raise
    return config


def load_model(
    model_path: Path,
    lazy: bool = False,
    strict: bool = True,
    sequential_load: bool = False,
    model_config: dict = {},
    get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> Tuple[nn.Module, dict]:
    """
    Load and initialize the model from a given path.

    Args:
        model_path (Path): The path to load the model from.
        lazy (bool): If ``False`` eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        strict (bool): Whether or not to raise an exception if weights don't
            match. Default: ``True``
        sequential_load (bool): If ``True``, evaluate the parameters layer by
            layer to avoid holding all of the weights in memory at once.
            Default: ``False``
        model_config (dict, optional): Optional configuration parameters for the
            model. Defaults to an empty dictionary.
        get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
            A function that returns the model class and model args class given a config.
            Defaults to the ``_get_classes`` function.

    Returns:
        Tuple[nn.Module, dict]: The loaded and initialized model along with its
            configuration.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If the model class or args class are not found or cannot be instantiated.
    """
    config = load_config(model_path)
    config.update(model_config)

    weight_files = glob.glob(str(model_path / "model*.safetensors"))

    if not weight_files:
        # Try weight for back-compat
        weight_files = glob.glob(str(model_path / "weight*.safetensors"))

    if not weight_files and strict:
        logging.error(f"No safetensors found in {model_path}")
        raise FileNotFoundError(f"No safetensors found in {model_path}")

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))

    model_class, model_args_class = get_model_classes(config=config)

    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)

    if hasattr(model, "sanitize"):
        weights = model.sanitize(weights)

    if (quantization := config.get("quantization", None)) is not None:

        def class_predicate(p, m):
            # Handle custom per layer quantizations
            if p in config["quantization"]:
                return config["quantization"][p]
            if not hasattr(m, "to_quantized"):
                return False
            # Handle legacy models which may not have everything quantized
            return f"{p}.scales" in weights

        nn.quantize(
            model,
            group_size=quantization["group_size"],
            bits=quantization["bits"],
            class_predicate=class_predicate,
        )

    model.load_weights(list(weights.items()), strict=strict)

    if mx.distributed.init().size() > 1:
        if not hasattr(model, "shard"):
            raise RuntimeError("Model doesn't support distributed inference.")
        model.shard()

    if not lazy:
        weights.clear()
        if sequential_load:
            for layer in model.layers:
                mx.eval(layer.parameters())
        mx.eval(model.parameters())

    model.eval()
    return model, config


def load(
    path_or_hf_repo: str,
    tokenizer_config={},
    model_config={},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
    sequential_load: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
    """
    Load the model and tokenizer from a given path or a huggingface repository.

    Args:
        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
            Defaults to an empty dictionary.
        model_config (dict, optional): Configuration parameters specifically for the model.
            Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
            to the model. Default: ``None``.
        lazy (bool): If ``False`` eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        sequential_load (bool): If ``True``, load each layer sequentially to
            avoid holding all of the weights in memory at once. Default: ``False``

    Returns:
        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """
    model_path = get_model_path(path_or_hf_repo)

    model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
    if adapter_path is not None:
        model = load_adapters(model, adapter_path)
        model.eval()
    tokenizer = load_tokenizer(
        model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None)
    )

    return model, tokenizer


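# Example (illustrative sketch; the repo id and adapter path are placeholders):
#
#     model, tokenizer = load(
#         "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
#         adapter_path="path/to/adapters",  # optional LoRA adapters
#         lazy=False,
#     )

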
def fetch_from_hub(
    model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
    model, config = load_model(model_path, lazy)
    tokenizer = load_tokenizer(
        model_path, eos_token_ids=config.get("eos_token_id", None)
    )
    return model, config, tokenizer


def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
    """
    Splits the weights into smaller shards.

    Args:
        weights (dict): Model weights.
        max_file_size_gb (int): Maximum size of each shard in gigabytes.

    Returns:
        list: List of weight shards.
    """
    max_file_size_bytes = max_file_size_gb << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        if shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards


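# Example (illustrative sketch): splitting a small dummy weight dict into
# 1 GB shards; with tiny arrays everything lands in a single shard.
#
#     weights = {f"layers.{i}.weight": mx.zeros((1024, 1024)) for i in range(8)}
#     shards = make_shards(weights, max_file_size_gb=1)
#     print(len(shards), [len(s) for s in shards])

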
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Uploads the model to Hugging Face hub.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
        hf_path (str): Path to the original Hugging Face model.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    from . import __version__

    card = ModelCard.load(hf_path)
    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
    card.data.base_model = hf_path
    card.text = dedent(
        f"""
        # {upload_repo}

        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
        converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
        using mlx-lm version **{__version__}**.

        ## Use with mlx

        ```bash
        pip install mlx-lm
        ```

        ```python
        from mlx_lm import load, generate

        model, tokenizer = load("{upload_repo}")

        prompt = "hello"

        if tokenizer.chat_template is not None:
            messages = [{{"role": "user", "content": prompt}}]
            prompt = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True
            )

        response = generate(model, tokenizer, prompt=prompt, verbose=True)
        ```
        """
    )
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi()
    api.create_repo(repo_id=upload_repo, exist_ok=True)
    api.upload_large_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")


def save_weights(
    save_path: Union[str, Path],
    weights: Dict[str, Any],
    *,
    donate_weights: bool = False,
) -> None:
    """Save model weights into specified directory."""
    if isinstance(save_path, str):
        save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    shards = make_shards(weights)
    shards_count = len(shards)
    shard_file_format = (
        "model-{:05d}-of-{:05d}.safetensors"
        if shards_count > 1
        else "model.safetensors"
    )

    total_size = sum(v.nbytes for v in weights.values())
    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}

    # Write the weights and make sure no references are kept other than the
    # necessary ones
    if donate_weights:
        weights.clear()
        del weights

    for i in range(len(shards)):
        shard = shards[i]
        shards[i] = None
        shard_name = shard_file_format.format(i + 1, shards_count)
        shard_path = save_path / shard_name

        mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"})

        for weight_name in shard.keys():
            index_data["weight_map"][weight_name] = shard_name
        del shard

    index_data["weight_map"] = {
        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
    }

    with open(save_path / "model.safetensors.index.json", "w") as f:
        json.dump(
            index_data,
            f,
            indent=4,
        )


def quantize_model(
    model: nn.Module,
    config: dict,
    q_group_size: int,
    q_bits: int,
    quant_predicate: Optional[
        Callable[[str, nn.Module, dict], Union[bool, dict]]
    ] = None,
) -> Tuple:
    """
    Applies quantization to the model weights.

    Args:
        model (nn.Module): The model to be quantized.
        config (dict): Model configuration.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.
        quant_predicate (Callable): A callable that decides how
            to quantize each layer based on the path.
            Accepts the layer `path`, the `module` and the model `config`.
            Returns either a bool to signify quantize/no quantize or
            a dict of quantization parameters to pass to `to_quantized`.

    Returns:
        Tuple: Tuple containing quantized weights and config.
    """
    quantized_config = copy.deepcopy(config)
    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}

    # Add any custom quantization parameters to the config as we go
    def _class_predicate(p, m):
        bool_or_params = quant_predicate(p, m, config)
        quantized_config["quantization"][p] = bool_or_params
        return bool_or_params

    nn.quantize(
        model,
        q_group_size,
        q_bits,
        class_predicate=_class_predicate if quant_predicate else None,
    )
    # support hf model tree #957
    quantized_config["quantization_config"] = quantized_config["quantization"]
    quantized_weights = dict(tree_flatten(model.parameters()))

    bpw = compute_bits_per_weight(model)
    print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")

    return quantized_weights, quantized_config


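# Example (illustrative sketch): a `quant_predicate` that keeps layers whose
# path matches a hypothetical "lm_head" pattern at 8 bits and quantizes the
# rest with the defaults, as described in the docstring above.
#
#     def mixed_precision(path, module, config):
#         if not hasattr(module, "to_quantized"):
#             return False
#         if "lm_head" in path:
#             return {"group_size": 64, "bits": 8}
#         return True
#
#     weights, config = quantize_model(
#         model, config, 64, 4, quant_predicate=mixed_precision
#     )

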
def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Save the model configuration to the ``config_path``.

    The final configuration will be sorted before saving for better readability.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """
    # Clean unused keys
    config.pop("_name_or_path", None)

    # sort the config for better readability
    config = dict(sorted(config.items()))

    # write the updated config to the config_path (if provided)
    with open(config_path, "w") as fid:
        json.dump(config, fid, indent=4)


def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    dtype: str = "float16",
    upload_repo: Optional[str] = None,
    revision: Optional[str] = None,
    dequantize: bool = False,
    quant_predicate: Optional[
        Callable[[str, nn.Module, dict], Union[bool, dict]]
    ] = None,
):
    # Check the save path is empty
    if isinstance(mlx_path, str):
        mlx_path = Path(mlx_path)

    if mlx_path.exists():
        raise ValueError(
            f"Cannot save to the path {mlx_path} as it already exists."
            " Please delete the file/directory or specify a new path to save to."
        )

    print("[INFO] Loading")
    model_path = get_model_path(hf_path, revision=revision)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

    weights = dict(tree_flatten(model.parameters()))
    dtype = getattr(mx, dtype)
    weights = {k: v.astype(dtype) for k, v in weights.items()}

    if quantize and dequantize:
        raise ValueError("Choose either quantize or dequantize, not both.")

    if quantize:
        print("[INFO] Quantizing")
        model.load_weights(list(weights.items()))
        weights, config = quantize_model(
            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
        )

    if dequantize:
        print("[INFO] Dequantizing")
        model = dequantize_model(model)
        weights = dict(tree_flatten(model.parameters()))

    del model
    save_weights(mlx_path, weights, donate_weights=True)

    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)

    save_config(config, config_path=mlx_path / "config.json")

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, hf_path)
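

# Example (illustrative sketch; the Hub repo id is a placeholder): convert and
# 4-bit quantize a Hugging Face model into a local MLX directory.
#
#     convert(
#         "mistralai/Mistral-7B-Instruct-v0.3",
#         mlx_path="mlx_model",
#         quantize=True,
#         q_bits=4,
#         q_group_size=64,
#     )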