# Copyright © 2023-2024 Apple Inc.

import contextlib
import copy
import glob
import importlib
import json
import logging
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_reduce
from tqdm import tqdm
from transformers import PreTrainedTokenizer

# Local imports
from .models import cache
from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters, nparams

# Constants
MODEL_REMAPPING = {
    "mistral": "llama",  # mistral is compatible with llama
    "phi-msft": "phixtral",
    "falcon_mamba": "mamba",
}

MAX_FILE_SIZE_GB = 5

# A stream on the default device just for generation
generation_stream = mx.new_stream(mx.default_device())


class ModelNotFoundError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)


@dataclass
class GenerationResponse:
    """
    The output of :func:`stream_generate`.

    Args:
        text (str): The next segment of decoded text. This can be an empty string.
        token (int): The next token.
        logprobs (mx.array): A vector of log probabilities.
        prompt_tokens (int): The number of tokens in the prompt.
        prompt_tps (float): The prompt processing tokens-per-second.
        generation_tokens (int): The number of generated tokens.
        generation_tps (float): The tokens-per-second for generation.
        peak_memory (float): The peak memory used so far in GB.
        finish_reason (str): The reason the response is being sent: "length",
            "stop", or ``None``.
    """

    text: str
    token: int
    logprobs: mx.array
    prompt_tokens: int
    prompt_tps: float
    generation_tokens: int
    generation_tps: float
    peak_memory: float
    finish_reason: Optional[str] = None


@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
    """
    A context manager to temporarily change the wired limit.

    Note, the wired limit should not be changed during an async eval. If an
    async eval could be running, pass in the streams to synchronize with prior
    to exiting the context manager.
    """
    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
    if model_bytes > 0.9 * max_rec_size:
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
        print(
            f"[WARNING] Generating with a model that requires {model_mb} MB "
            f"which is close to the maximum recommended size of {max_rec_mb} "
            "MB. This can be slow. See the documentation for possible work-arounds: "
            "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
        )
    old_limit = mx.metal.set_wired_limit(max_rec_size)
    try:
        yield None
    finally:
        if streams is not None:
            for s in streams:
                mx.synchronize(s)
        else:
            mx.synchronize()
        mx.metal.set_wired_limit(old_limit)


def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
    model_type = config["model_type"]
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    try:
        arch = importlib.import_module(f"mlx_lm.models.{model_type}")
    except ImportError:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        raise ValueError(msg)
    return arch.Model, arch.ModelArgs


def compute_bits_per_weight(model):
    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    leaf_modules = tree_flatten(
        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
    )
    model_params = sum(nparams(m) for _, m in leaf_modules)
    return model_bytes * 8 / model_params


def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of
            the model.
        revision (str, optional): A revision id which can be a branch name, a
            tag, or a commit hash.

    Returns:
        Path: The path to the model.
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        try:
            model_path = Path(
                snapshot_download(
                    repo_id=path_or_hf_repo,
                    revision=revision,
                    allow_patterns=[
                        "*.json",
                        "*.safetensors",
                        "*.py",
                        "tokenizer.model",
                        "*.tiktoken",
                        "*.txt",
                    ],
                )
            )
        except Exception:
            raise ModelNotFoundError(
                f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
                "Please make sure you specified the local path or Hugging Face"
                " repo id correctly.\nIf you are trying to access a private or"
                " gated Hugging Face repo, make sure you are authenticated:\n"
                "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
            ) from None
    return model_path


def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
    if (
        kv_bits is not None
        and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
        and prompt_cache[0].offset > quantized_kv_start
    ):
        for i in range(len(prompt_cache)):
            if isinstance(prompt_cache[i], cache.KVCache):
                prompt_cache[i] = prompt_cache[i].to_quantized(
                    group_size=kv_group_size, bits=kv_bits
                )


def generate_step(
    prompt: mx.array,
    model: nn.Module,
    *,
    max_tokens: int = 256,
    sampler: Optional[Callable[[mx.array], mx.array]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
    prefill_step_size: int = 512,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
    temp: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    top_p: Optional[float] = None,
    min_p: Optional[float] = None,
    min_tokens_to_keep: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        max_tokens (int): The maximum number of tokens. Use ``-1`` for an
            infinite generator. Default: ``256``.
        sampler (Callable[[mx.array], mx.array], optional): A sampler for
            sampling a token from a vector of log probabilities. Default: ``None``.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
            A list of functions that take tokens and logits and return the
            processed logits. Default: ``None``.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
            entries (except the first 4 tokens) will be overwritten.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
            provided, the cache will be updated in place.
        prefill_step_size (int): Step size for processing the prompt.
        kv_bits (int, optional): Number of bits to use for KV cache quantization.
            None implies no cache quantization. Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache when
            ``kv_bits`` is non-None. Default: ``0``.
        prompt_progress_callback (Callable[[int, int], None], optional): A
            callback which takes the prompt tokens processed so far and the
            total number of prompt tokens.

    Yields:
        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
    """

    y = prompt
    tokens = None

    # Create the KV cache for generation
    if prompt_cache is None:
        prompt_cache = cache.make_prompt_cache(
            model,
            max_kv_size=max_kv_size,
        )
    elif len(prompt_cache) != len(model.layers):
        raise ValueError("Wrong number of layers in the prompt cache.")

    if temp is not None or top_p is not None or min_tokens_to_keep is not None:
        print(
            "[Warning] Specifying sampling arguments to ``generate_step`` is "
            "deprecated. Pass in a ``sampler`` instead."
        )
    if repetition_penalty is not None:
        print(
            "[Warning] Specifying ``repetition_penalty`` is deprecated. "
            "Pass in ``logits_processors`` instead."
        )

    sampler = sampler or make_sampler(
        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
    )
    logits_processors = logits_processors or make_logits_processors(
        None, repetition_penalty, repetition_context_size or 20
    )
    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)

    def _step(y):
        with mx.stream(generation_stream):
            logits = model(y[None], cache=prompt_cache)
            logits = logits[:, -1, :]

            if logits_processors:
                nonlocal tokens
                tokens = mx.concat([tokens, y]) if tokens is not None else y

                for processor in logits_processors:
                    logits = processor(tokens, logits)

            maybe_quantize_kv_cache(
                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
            )

            logprobs = logits - mx.logsumexp(logits, keepdims=True)
            y = sampler(logprobs)
            return y, logprobs.squeeze(0)

    with mx.stream(generation_stream):
        total_prompt_tokens = y.size
        prompt_processed_tokens = 0
        while y.size > prefill_step_size:
            model(y[:prefill_step_size][None], cache=prompt_cache)
            maybe_quantize_kv_cache(
                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
            )
            mx.eval([c.state for c in prompt_cache])
            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
            prompt_processed_tokens += prefill_step_size
            y = y[prefill_step_size:]
            mx.metal.clear_cache()

        y, logprobs = _step(y)

    mx.async_eval(y, logprobs)
    n = 0
    while True:
        if n != max_tokens:
            next_y, next_logprobs = _step(y)
            mx.async_eval(next_y, next_logprobs)
        if n == 0:
            mx.eval(y)
            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
        if n == max_tokens:
            break
        yield y.item(), logprobs
        if n % 256 == 0:
            mx.metal.clear_cache()
        y, logprobs = next_y, next_logprobs
        n += 1


def stream_generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, mx.array, List[int]],
    **kwargs,
) -> Generator[GenerationResponse, None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        model (nn.Module): The model to use for generation.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (Union[str, mx.array, List[int]]): The input prompt string or
            integer tokens.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.

    Yields:
        GenerationResponse: An instance containing the generated text segment and
            associated metadata. See :class:`GenerationResponse` for details.
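
    Example:
        A minimal sketch of streaming generation. The repository id is
        illustrative; :func:`load` is defined later in this module::

            model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
            for response in stream_generate(model, tokenizer, "Hello", max_tokens=64):
                print(response.text, end="", flush=True)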
""" if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) if not isinstance(prompt, mx.array): prompt = mx.array( prompt if isinstance(prompt, list) else tokenizer.encode(prompt) ) detokenizer = tokenizer.detokenizer with wired_limit(model, [generation_stream]): detokenizer.reset() tic = time.perf_counter() for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)): if n == 0: prompt_time = time.perf_counter() - tic prompt_tps = prompt.size / prompt_time tic = time.perf_counter() if token in tokenizer.eos_token_ids: break detokenizer.add_token(token) yield GenerationResponse( text=detokenizer.last_segment, token=token, logprobs=logprobs, prompt_tokens=prompt.size, prompt_tps=prompt_tps, generation_tokens=n + 1, generation_tps=(n + 1) / (time.perf_counter() - tic), peak_memory=mx.metal.get_peak_memory() / 1e9, finish_reason=None, ) detokenizer.finalize() yield GenerationResponse( text=detokenizer.last_segment, token=token, logprobs=logprobs, prompt_tokens=prompt.size, prompt_tps=prompt_tps, generation_tokens=n + 1, generation_tps=(n + 1) / (time.perf_counter() - tic), peak_memory=mx.metal.get_peak_memory() / 1e9, finish_reason="stop" if token in tokenizer.eos_token_ids else "length", ) def generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], prompt: str, verbose: bool = False, formatter: Optional[Callable] = None, stop_strings: Optional[List[str]] = None, **kwargs, ) -> str: """ Generate a complete response from the model. Args: model (nn.Module): The language model. tokenizer (PreTrainedTokenizer): The tokenizer. prompt (str): The string prompt. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. kwargs: The remaining options get passed to :func:`stream_generate`. See :func:`stream_generate` for more details. """ if formatter is not None: print( "[Warning] Text formatting is deprecated and no longer used. " "The argument will be removed in a future version." ) if verbose: print("=" * 10) print("Prompt:", prompt) text = "" for response in stream_generate(model, tokenizer, prompt, **kwargs): if verbose: print(response.text, end="", flush=True) text += response.text if stop_strings is not None and any(s in text for s in stop_strings): break if verbose: print() print("=" * 10) if len(text) == 0: print("No text generated for this prompt") return print( f"Prompt: {response.prompt_tokens} tokens, " f"{response.prompt_tps:.3f} tokens-per-sec" ) print( f"Generation: {response.generation_tokens} tokens, " f"{response.generation_tps:.3f} tokens-per-sec" ) print(f"Peak memory: {response.peak_memory:.3f} GB") return text def load_config(model_path: Path) -> dict: try: with open(model_path / "config.json", "r") as f: config = json.load(f) except FileNotFoundError: logging.error(f"Config file not found in {model_path}") raise return config def load_model( model_path: Path, lazy: bool = False, model_config: dict = {}, get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes, ) -> nn.Module: """ Load and initialize the model from a given path. Args: model_path (Path): The path to load the model from. lazy (bool): If False eval the model parameters to make sure they are loaded in memory before returning, otherwise they will be loaded when needed. Default: ``False`` model_config (dict, optional): Optional configuration parameters for the model. Defaults to an empty dictionary. 
        get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
            A function that returns the model class and model args class given a
            config. Defaults to the ``_get_classes`` function.

    Returns:
        Tuple[nn.Module, dict]: The loaded and initialized model along with its
            configuration.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If the model class or args class are not found or cannot be
            instantiated.
    """
    config = load_config(model_path)
    config.update(model_config)

    weight_files = glob.glob(str(model_path / "model*.safetensors"))

    if not weight_files:
        # Try weight for back-compat
        weight_files = glob.glob(str(model_path / "weight*.safetensors"))

    if not weight_files:
        logging.error(f"No safetensors found in {model_path}")
        raise FileNotFoundError(f"No safetensors found in {model_path}")

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))

    model_class, model_args_class = get_model_classes(config=config)

    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)

    if hasattr(model, "sanitize"):
        weights = model.sanitize(weights)

    if (quantization := config.get("quantization", None)) is not None:

        def class_predicate(p, m):
            # Handle custom per layer quantizations
            if p in config["quantization"]:
                return config["quantization"][p]
            if not hasattr(m, "to_quantized"):
                return False
            # Handle legacy models which may not have everything quantized
            return f"{p}.scales" in weights

        nn.quantize(
            model,
            group_size=quantization["group_size"],
            bits=quantization["bits"],
            class_predicate=class_predicate,
        )

    model.load_weights(list(weights.items()))

    if not lazy:
        mx.eval(model.parameters())

    model.eval()
    return model, config


def load(
    path_or_hf_repo: str,
    tokenizer_config={},
    model_config={},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
    """
    Load the model and tokenizer from a given path or a Hugging Face repository.

    Args:
        path_or_hf_repo (Path): The path or the Hugging Face repository to load
            the model from.
        tokenizer_config (dict, optional): Configuration parameters specifically
            for the tokenizer. Defaults to an empty dictionary.
        model_config (dict, optional): Configuration parameters specifically for
            the model. Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided,
            applies LoRA layers to the model. Default: ``None``.
        lazy (bool): If ``False``, eval the model parameters to make sure they
            are loaded in memory before returning; otherwise they will be
            loaded when needed. Default: ``False``

    Returns:
        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model
            and tokenizer.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """
    model_path = get_model_path(path_or_hf_repo)

    model, config = load_model(model_path, lazy)
    if adapter_path is not None:
        model = load_adapters(model, adapter_path)
        model.eval()
    tokenizer = load_tokenizer(
        model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None)
    )

    return model, tokenizer


def fetch_from_hub(
    model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
    model, config = load_model(model_path, lazy)
    tokenizer = load_tokenizer(
        model_path, eos_token_ids=config.get("eos_token_id", None)
    )
    return model, config, tokenizer


def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
    """
    Splits the weights into smaller shards.

    Args:
        weights (dict): Model weights.
        max_file_size_gb (int): Maximum size of each shard in gigabytes.

    Returns:
        list: List of weight shards.
    """
    max_file_size_bytes = max_file_size_gb << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        if shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards


def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Uploads the model to Hugging Face hub.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
        hf_path (str): Path to the original Hugging Face model.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    from . import __version__

    card = ModelCard.load(hf_path)
    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
    card.data.base_model = hf_path
    card.text = dedent(
        f"""
        # {upload_repo}

        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
        converted to MLX format from
        [{hf_path}](https://huggingface.co/{hf_path})
        using mlx-lm version **{__version__}**.

        ## Use with mlx

        ```bash
        pip install mlx-lm
        ```

        ```python
        from mlx_lm import load, generate

        model, tokenizer = load("{upload_repo}")

        prompt = "hello"

        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
            messages = [{{"role": "user", "content": prompt}}]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        response = generate(model, tokenizer, prompt=prompt, verbose=True)
        ```
        """
    )
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi()
    api.create_repo(repo_id=upload_repo, exist_ok=True)
    api.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
        multi_commits=True,
        multi_commits_verbose=True,
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")


def save_weights(
    save_path: Union[str, Path],
    weights: Dict[str, Any],
    *,
    donate_weights: bool = False,
) -> None:
    """Save model weights into specified directory."""
    if isinstance(save_path, str):
        save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    shards = make_shards(weights)
    shards_count = len(shards)
    shard_file_format = (
        "model-{:05d}-of-{:05d}.safetensors"
        if shards_count > 1
        else "model.safetensors"
    )

    total_size = sum(v.nbytes for v in weights.values())
    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}

    # Write the weights and make sure no references are kept other than the
    # necessary ones
    if donate_weights:
        weights.clear()
        del weights

    for i in range(len(shards)):
        shard = shards[i]
        shards[i] = None
        shard_name = shard_file_format.format(i + 1, shards_count)
        shard_path = save_path / shard_name

        mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"})

        for weight_name in shard.keys():
            index_data["weight_map"][weight_name] = shard_name
        del shard

    index_data["weight_map"] = {
        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
    }

    with open(save_path / "model.safetensors.index.json", "w") as f:
        json.dump(
            index_data,
            f,
            indent=4,
        )


def quantize_model(
    model: nn.Module,
    config: dict,
    q_group_size: int,
    q_bits: int,
    quant_predicate: Optional[
        Callable[[str, nn.Module, dict], Union[bool, dict]]
    ] = None,
) -> Tuple:
    """
    Applies quantization to the model weights.

    Args:
        model (nn.Module): The model to be quantized.
        config (dict): Model configuration.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.
        quant_predicate (Callable): A callable that decides how to quantize each
            layer based on the path. Accepts the layer ``path``, the ``module``
            and the model ``config``. Returns either a bool to signify
            quantize/no quantize or a dict of quantization parameters to pass to
            ``to_quantized``.

    Returns:
        Tuple: Tuple containing quantized weights and config.
    """
    quantized_config = copy.deepcopy(config)
    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}

    # Add any custom quantization parameters to the config as we go
    def _class_predicate(p, m):
        bool_or_params = quant_predicate(p, m, config)
        quantized_config["quantization"][p] = bool_or_params
        return bool_or_params

    nn.quantize(
        model,
        q_group_size,
        q_bits,
        class_predicate=_class_predicate if quant_predicate else None,
    )

    # support hf model tree #957
    quantized_config["quantization_config"] = quantized_config["quantization"]
    quantized_weights = dict(tree_flatten(model.parameters()))

    bpw = compute_bits_per_weight(model)
    print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")

    return quantized_weights, quantized_config


def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Save the model configuration to the ``config_path``.

    The final configuration will be sorted before saving for better readability.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """
    # Clean unused keys
    config.pop("_name_or_path", None)

    # sort the config for better readability
    config = dict(sorted(config.items()))

    # write the updated config to the config_path (if provided)
    with open(config_path, "w") as fid:
        json.dump(config, fid, indent=4)


def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    dtype: str = "float16",
    upload_repo: Optional[str] = None,
    revision: Optional[str] = None,
    dequantize: bool = False,
    quant_predicate: Optional[
        Callable[[str, nn.Module, dict], Union[bool, dict]]
    ] = None,
):
    # Check the save path is empty
    if isinstance(mlx_path, str):
        mlx_path = Path(mlx_path)

    if mlx_path.exists():
        raise ValueError(
            f"Cannot save to the path {mlx_path} as it already exists."
            " Please delete the file/directory or specify a new path to save to."
        )

    print("[INFO] Loading")
    model_path = get_model_path(hf_path, revision=revision)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

    weights = dict(tree_flatten(model.parameters()))
    dtype = getattr(mx, dtype)
    weights = {k: v.astype(dtype) for k, v in weights.items()}

    if quantize and dequantize:
        raise ValueError("Choose either quantize or dequantize, not both.")

    if quantize:
        print("[INFO] Quantizing")
        model.load_weights(list(weights.items()))
        weights, config = quantize_model(
            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
        )

    if dequantize:
        print("[INFO] Dequantizing")
        model = dequantize_model(model)
        weights = dict(tree_flatten(model.parameters()))

    del model
    save_weights(mlx_path, weights, donate_weights=True)

    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)

    save_config(config, config_path=mlx_path / "config.json")

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, hf_path)


def generate_batched_response(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, mx.array, List[int]],
    batch_size: int,
    max_tokens: int = 256,
    sampler: Optional[Callable[[mx.array], mx.array]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[List[Any]] = None,
    prefill_step_size: int = 512,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
    temp: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    top_p: Optional[float] = None,
    min_p: Optional[float] = None,
    min_tokens_to_keep: Optional[int] = None,
    verbose: bool = False,
) -> List[str]:
    """
    Generate multiple responses to the same prompt in parallel and return only
    the generated sequences (excluding the prompt), stopping at the first EOS
    token.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer or TokenizerWrapper): The tokenizer.
        prompt (Union[str, mx.array, List[int]]): The input prompt.
        batch_size (int): Number of responses to generate in parallel.
        max_tokens (int): Maximum number of generated tokens per sequence.
        sampler (Callable): Sampler function.
        logits_processors (List[Callable]): List of logits processors.
        max_kv_size (int): Maximum KV cache size.
        prompt_cache (List[Any]): Precomputed prompt cache.
        prefill_step_size (int): Step size for prompt processing.
        kv_bits (int): Bits for KV cache quantization.
        kv_group_size (int): Group size for KV quantization.
        quantized_kv_start (int): Step to begin quantizing KV.
        prompt_progress_callback (Callable): Callback for prompt progress.
        temp (float): Temperature for sampling (deprecated, pass to sampler).
        repetition_penalty (float): Repetition penalty (deprecated, use
            logits_processors).
        repetition_context_size (int): Context size for repetition.
        top_p (float): Top-p sampling (deprecated, pass to sampler).
        min_p (float): Minimum p sampling (deprecated, pass to sampler).
        min_tokens_to_keep (int): Minimum number of tokens to keep.
        verbose (bool): If True, show a progress bar for token generation.

    Returns:
        List[str]: A list of decoded response strings for each batch element,
            excluding the prompt and stopping at the first EOS token.
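
    Example:
        A minimal sketch, assuming ``model`` and ``tokenizer`` come from
        :func:`load`::

            responses = generate_batched_response(
                model, tokenizer, "Write a haiku about the sea.",
                batch_size=4, max_tokens=64, verbose=True,
            )
            for text in responses:
                print(text)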
""" if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) # Convert prompt to tokens if necessary if not isinstance(prompt, mx.array): prompt = mx.array( prompt if isinstance(prompt, list) else tokenizer.encode(prompt) ) # Expand prompt to batch prompt_length = prompt.size prompt = mx.expand_dims(prompt, 0) # (1, prompt_length) prompt = mx.repeat(prompt, batch_size, axis=0) # (B, prompt_length) B = batch_size if prompt_progress_callback is None: prompt_progress_callback = lambda *_: None if temp is not None or top_p is not None or min_tokens_to_keep is not None: print( "[Warning] Specifying sampling arguments directly is deprecated. " "Pass in a `sampler` if needed." ) if repetition_penalty is not None: print( "[Warning] Specifying `repetition_penalty` is deprecated. " "Use `logits_processors` instead." ) sampler = sampler or make_sampler( temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1 ) logits_processors = logits_processors or make_logits_processors( None, repetition_penalty, repetition_context_size or 20 ) # Create or verify prompt cache if prompt_cache is None: prompt_cache = cache.make_prompt_cache(model, max_kv_size) elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") # Process the prompt to fill the cache in increments total_prompt_tokens = prompt_length prompt_processed_tokens = 0 remaining_prompt = prompt tic = time.perf_counter() with mx.stream(generation_stream): while remaining_prompt.shape[1] > prefill_step_size: model(remaining_prompt[:, :prefill_step_size], cache=prompt_cache) mx.eval([c.state for c in prompt_cache]) prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) prompt_processed_tokens += prefill_step_size remaining_prompt = remaining_prompt[:, prefill_step_size:] mx.metal.clear_cache() # Process any remaining prompt tokens if remaining_prompt.shape[1] > 0: model(remaining_prompt, cache=prompt_cache) mx.eval([c.state for c in prompt_cache]) prompt_progress_callback(total_prompt_tokens, total_prompt_tokens) prompt_time = time.perf_counter() - tic prompt_tps = (total_prompt_tokens * B) / prompt_time # Initialization for generation tokens = prompt finished = mx.zeros((B,), dtype=tokens.dtype) generation_count = 0 eos_ids = tokenizer.eos_token_ids # Setup progress bar if verbose pbar = None if verbose: if max_tokens >= 0: pbar = tqdm(total=max_tokens, desc="Generating tokens", ncols=80) else: # If we don't have a max_tokens limit, no total is known. # We'll just display a progress bar that counts up. 
            pbar = tqdm(desc="Generating tokens", ncols=80)

    tic = time.perf_counter()
    while True:
        if (max_tokens >= 0) and (generation_count >= max_tokens):
            break

        # If all sequences finished, break
        sum_finished = mx.sum(finished)
        mx.eval(sum_finished)
        if sum_finished.item() == B:
            break

        # Prepare last token
        next_input = tokens[:, -1:]  # (B, 1)

        with mx.stream(generation_stream):
            logits = model(next_input, cache=prompt_cache)  # logits: (B, 1, vocab)
            logits = logits[:, -1, :]  # (B, vocab)

            # Apply logits processors
            if logits_processors:
                for processor in logits_processors:
                    logits = processor(tokens, logits)

            maybe_quantize_kv_cache(
                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
            )

            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)  # (B, vocab)
            sampled_tokens = sampler(logprobs)  # (B,)
            mx.async_eval(sampled_tokens, logprobs)

        # Check EOS: mark a sequence as finished once it samples any EOS id
        is_eos = mx.zeros_like(sampled_tokens).astype(tokens.dtype)
        for eid in eos_ids:
            is_eos = mx.maximum(is_eos, (sampled_tokens == eid).astype(tokens.dtype))
        finished = mx.maximum(finished, is_eos)

        sampled_tokens = sampled_tokens[:, None]  # (B, 1)
        tokens = mx.concatenate([tokens, sampled_tokens], axis=1)

        generation_count += 1
        if pbar is not None:
            pbar.update(1)

        if (generation_count % 256) == 0:
            mx.metal.clear_cache()

    if pbar is not None:
        pbar.close()

    generation_time = time.perf_counter() - tic
    generation_tps = (
        (generation_count * B) / generation_time if generation_count > 0 else 0.0
    )
    peak_memory = mx.metal.get_peak_memory() / 1e9

    results = []
    for i in range(B):
        seq = tokens[i][prompt_length:].tolist()  # Exclude the prompt
        # Find the first EOS token
        eos_pos = None
        for idx, t in enumerate(seq):
            if t in eos_ids:
                eos_pos = idx
                break
        # Slice up to EOS if found
        if eos_pos is not None:
            seq = seq[:eos_pos]
        text = tokenizer.decode(seq)
        results.append(text)

    if verbose:
        print("=" * 10)
        print(
            f"Prompt: {total_prompt_tokens} tokens * {B} sequences, "
            f"{prompt_tps:.3f} tps"
        )
        print(
            f"Generation: {generation_count} tokens * {B} sequences, "
            f"{generation_tps:.3f} tps"
        )
        print(f"Peak memory: {peak_memory:.3f} GB")

    return results
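

# ---------------------------------------------------------------------------
# Example usage (an illustrative sketch, not part of the library's public API).
# The repository id below is a placeholder for any MLX-format model on the
# Hugging Face Hub; substitute a local path or another repo id as needed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_model, demo_tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

    # Single-sequence generation with timing information printed to stdout.
    generate(demo_model, demo_tokenizer, prompt="Hello", max_tokens=32, verbose=True)

    # Several completions of the same prompt, generated in one batch.
    for completion in generate_batched_response(
        demo_model, demo_tokenizer, "Write a haiku about the sea.",
        batch_size=2, max_tokens=32,
    ):
        print(completion)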