implement batch_generate

This commit is contained in:
L Lllvvuu 2024-12-27 13:47:09 -08:00
parent cded14988c
commit 465eb79fff
No known key found for this signature in database
GPG Key ID: CFAD5A25056DDD0F
2 changed files with 100 additions and 5 deletions

View File

@ -21,6 +21,7 @@ from transformers import PreTrainedTokenizer
# Local imports
from .models import cache
from .models.base import create_causal_mask
from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
@ -355,6 +356,7 @@ def generate_step(
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
if n == max_tokens:
break
mx.eval(y)
yield y, logprobs
if n % 256 == 0:
mx.metal.clear_cache()
@ -488,8 +490,85 @@ def generate(
return text
def batch_generate():
pass
def batch_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompts: list[str],
verbose: bool = False,
**kwargs,
) -> list[str]:
"""
Generate complete responses from the model for a list of prompts.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompts (List[str]): The string prompts.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
"""
if 'prompt_cache' in kwargs:
# TODO: Handle `prompt_cache` and `prompt` both left-padded, so that
# we have <pad>text<pad>text. Should involve taking `prompt_cache_lens`
# to extend `mask` below, and handling position_ids (see TODO below)
raise ValueError("Batch generation does not support prompt_cache yet.")
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
# TODO: left-shift position_ids for absolute/rotary positional encodings
# Example: https://github.com/huggingface/transformers/issues/26072#issuecomment-2101209470
tokenizer._tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
tokenizer._tokenizer.pad_token = tokenizer.eos_token
tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
res = tokenizer._tokenizer(prompts, padding=True)
input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
causal_mask = create_causal_mask(token_mask.shape[-1])
mask = mx.where(token_mask[:, None, None, :], causal_mask, -1e9)
output_toks = []
prompt_time = None
ended = mx.zeros(len(prompts), dtype=mx.bool_)
tic = time.perf_counter()
# TODO: non-generator version of `generate_step` so that we can
# add or remove prompts from the batch as they start/finish
for tokens, _ in generate_step(input_ids, model, mask=mask, **kwargs):
if not prompt_time:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
ended = ended | (tokens == tokenizer.eos_token_id)
if ended.all():
break
output_toks.append(tokens)
if verbose:
print(".", end="", flush=True)
output_toks = mx.stack(output_toks, axis=-1)
token_count = output_toks.size
response = [
response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
for response in tokenizer.batch_decode(output_toks.tolist())
]
if verbose:
gen_time = time.perf_counter() - tic
if token_count <= 0:
print("No tokens generated for this prompt")
else:
print()
for p, resp in zip(prompts, response):
print("=" * 10)
print("Prompt:", p)
print(resp)
print("=" * 10)
if prompt_time:
prompt_tps = input_ids.size / prompt_time
print(f"Prompt: {input_ids.size} tokens, {prompt_tps:.3f} tokens-per-sec")
gen_tps = token_count / gen_time
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30
print(f"Peak memory: {peak_mem:.3f} GB")
return response
def load_config(model_path: Path) -> dict:

View File

@ -2,12 +2,11 @@
import unittest
from mlx_lm.sample_utils import make_logits_processors
from mlx_lm.utils import generate, load
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import generate, batch_generate, load
class TestGenerate(unittest.TestCase):
@classmethod
def setUpClass(cls):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
@ -51,6 +50,23 @@ class TestGenerate(unittest.TestCase):
)
self.assertEqual(all_toks.shape[-1], len(init_toks) + 5)
def test_batch_generate(self):
logit_bias = {0: 20.0, 1: -20.0}
texts = batch_generate(
self.model,
self.tokenizer,
[
"hello",
"this is a longer prompt to test out the padding and masking. hello",
],
max_tokens=5,
prefill_step_size=4,
sampler=make_sampler(temp=1.0, min_p=0.1),
logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0),
verbose=False,
)
self.assertEqual(texts, ['!', '!'])
if __name__ == "__main__":
unittest.main()