diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index b4f7728d..f28fd830 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -21,6 +21,7 @@ from transformers import PreTrainedTokenizer
 
 # Local imports
 from .models import cache
+from .models.base import create_causal_mask
 from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
@@ -355,6 +356,7 @@ def generate_step(
             prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
         if n == max_tokens:
             break
+        mx.eval(y)
         yield y, logprobs
         if n % 256 == 0:
             mx.metal.clear_cache()
@@ -488,8 +490,85 @@ def generate(
     return text
 
 
-def batch_generate():
-    pass
+def batch_generate(
+    model: nn.Module,
+    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
+    prompts: list[str],
+    verbose: bool = False,
+    **kwargs,
+) -> list[str]:
+    """
+    Generate complete responses from the model for a list of prompts.
+
+    Args:
+        model (nn.Module): The language model.
+        tokenizer (PreTrainedTokenizer): The tokenizer.
+        prompts (List[str]): The string prompts.
+        verbose (bool): If ``True``, print tokens and timing information.
+            Default: ``False``.
+        kwargs: The remaining options get passed to :func:`generate_step`.
+            See :func:`generate_step` for more details.
+    """
+    if "prompt_cache" in kwargs:
+        # TODO: Handle `prompt_cache` and `prompt` both left-padded, so that
+        # we have <pad>text<pad>text. Should involve taking `prompt_cache_lens`
+        # to extend `mask` below, and handling position_ids (see TODO below)
+        raise ValueError("Batch generation does not support prompt_cache yet.")
+    if not isinstance(tokenizer, TokenizerWrapper):
+        tokenizer = TokenizerWrapper(tokenizer)
+    # TODO: left-shift position_ids for absolute/rotary positional encodings
+    # Example: https://github.com/huggingface/transformers/issues/26072#issuecomment-2101209470
+    tokenizer._tokenizer.padding_side = "left"
+    if tokenizer.pad_token is None:
+        tokenizer._tokenizer.pad_token = tokenizer.eos_token
+        tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
+    res = tokenizer._tokenizer(prompts, padding=True)
+    input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
+    causal_mask = create_causal_mask(token_mask.shape[-1])
+    mask = mx.where(token_mask[:, None, None, :], causal_mask, -1e9)
+
+    output_toks = []
+    prompt_time = None
+    ended = mx.zeros(len(prompts), dtype=mx.bool_)
+    tic = time.perf_counter()
+    # TODO: non-generator version of `generate_step` so that we can
+    # add or remove prompts from the batch as they start/finish
+    for tokens, _ in generate_step(input_ids, model, mask=mask, **kwargs):
+        if not prompt_time:
+            prompt_time = time.perf_counter() - tic
+            tic = time.perf_counter()
+        ended = ended | (tokens == tokenizer.eos_token_id)
+        if ended.all():
+            break
+        output_toks.append(tokens)
+        if verbose:
+            print(".", end="", flush=True)
+    output_toks = mx.stack(output_toks, axis=-1)
+    token_count = output_toks.size
+    response = [
+        response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
+        for response in tokenizer.batch_decode(output_toks.tolist())
+    ]
+
+    if verbose:
+        gen_time = time.perf_counter() - tic
+        if token_count <= 0:
+            print("No tokens generated for these prompts")
+        else:
+            print()
+            for p, resp in zip(prompts, response):
+                print("=" * 10)
+                print("Prompt:", p)
+                print(resp)
+                print("=" * 10)
+            if prompt_time:
+                prompt_tps = input_ids.size / prompt_time
+                print(f"Prompt: {input_ids.size} tokens, {prompt_tps:.3f} tokens-per-sec")
+            gen_tps = token_count / gen_time
+            print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
+            peak_mem = mx.metal.get_peak_memory() / 2**30
+            print(f"Peak memory: {peak_mem:.3f} GB")
+
+    return response
 
 
 def load_config(model_path: Path) -> dict:
diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py
index b069edef..14fa75e9 100644
--- a/llms/tests/test_generate.py
+++ b/llms/tests/test_generate.py
@@ -2,12 +2,11 @@
 
 import unittest
 
-from mlx_lm.sample_utils import make_logits_processors
-from mlx_lm.utils import generate, load
+from mlx_lm.sample_utils import make_logits_processors, make_sampler
+from mlx_lm.utils import generate, batch_generate, load
 
 
 class TestGenerate(unittest.TestCase):
-
     @classmethod
     def setUpClass(cls):
         HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
@@ -51,6 +50,23 @@ class TestGenerate(unittest.TestCase):
         )
         self.assertEqual(all_toks.shape[-1], len(init_toks) + 5)
 
+    def test_batch_generate(self):
+        logit_bias = {0: 20.0, 1: -20.0}
+        texts = batch_generate(
+            self.model,
+            self.tokenizer,
+            [
+                "hello",
+                "this is a longer prompt to test out the padding and masking. hello",
+            ],
+            max_tokens=5,
+            prefill_step_size=4,
+            sampler=make_sampler(temp=1.0, min_p=0.1),
+            logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0),
+            verbose=False,
+        )
+        self.assertEqual(texts, ["!", "!"])
+
 
 if __name__ == "__main__":
     unittest.main()
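
For reviewers, here is a minimal usage sketch of the new `batch_generate` entry point. The prompts and `max_tokens` value are illustrative; the model path is the same one the test suite uses:

```python
from mlx_lm.utils import batch_generate, load

# Load a small quantized model (the one used in test_generate.py).
model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

# Prompts of different lengths are left-padded into a single batch;
# extra keyword options (e.g. max_tokens) are forwarded to generate_step.
responses = batch_generate(
    model,
    tokenizer,
    prompts=["What is MLX?", "Write one sentence about batching."],
    max_tokens=64,
    verbose=True,
)
for r in responses:
    print(r)
```

Left-padding (rather than right-padding) is used so that the last position of every row is a real token, letting all rows in the batch decode their next token in the same step; the attention mask built from `attention_mask` keeps the pad positions from influencing the real tokens.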