Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-29 03:01:34 +08:00
implement batch_generate
This commit is contained in:
parent cded14988c
commit 465eb79fff
@@ -21,6 +21,7 @@ from transformers import PreTrainedTokenizer
 
 # Local imports
 from .models import cache
+from .models.base import create_causal_mask
 from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
@@ -355,6 +356,7 @@ def generate_step(
             prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
         if n == max_tokens:
             break
         mx.eval(y)
         yield y, logprobs
         if n % 256 == 0:
             mx.metal.clear_cache()
@@ -488,8 +490,85 @@ def generate(
     return text
 
 
-def batch_generate():
-    pass
+def batch_generate(
+    model: nn.Module,
+    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
+    prompts: list[str],
+    verbose: bool = False,
+    **kwargs,
+) -> list[str]:
+    """
+    Generate complete responses from the model for a list of prompts.
+
+    Args:
+        model (nn.Module): The language model.
+        tokenizer (PreTrainedTokenizer): The tokenizer.
+        prompts (List[str]): The string prompts.
+        verbose (bool): If ``True``, print tokens and timing information.
+            Default: ``False``.
+        kwargs: The remaining options get passed to :func:`generate_step`.
+            See :func:`generate_step` for more details.
+    """
+    if 'prompt_cache' in kwargs:
+        # TODO: Handle `prompt_cache` and `prompt` both left-padded, so that
+        # we have <pad>text<pad>text. Should involve taking `prompt_cache_lens`
+        # to extend `mask` below, and handling position_ids (see TODO below)
+        raise ValueError("Batch generation does not support prompt_cache yet.")
+    if not isinstance(tokenizer, TokenizerWrapper):
+        tokenizer = TokenizerWrapper(tokenizer)
+    # TODO: left-shift position_ids for absolute/rotary positional encodings
+    # Example: https://github.com/huggingface/transformers/issues/26072#issuecomment-2101209470
+    tokenizer._tokenizer.padding_side = "left"
+    if tokenizer.pad_token is None:
+        tokenizer._tokenizer.pad_token = tokenizer.eos_token
+        tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
+    res = tokenizer._tokenizer(prompts, padding=True)
+    input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
+    causal_mask = create_causal_mask(token_mask.shape[-1])
+    mask = mx.where(token_mask[:, None, None, :], causal_mask, -1e9)
+
+    output_toks = []
+    prompt_time = None
+    ended = mx.zeros(len(prompts), dtype=mx.bool_)
+    tic = time.perf_counter()
+    # TODO: non-generator version of `generate_step` so that we can
+    # add or remove prompts from the batch as they start/finish
+    for tokens, _ in generate_step(input_ids, model, mask=mask, **kwargs):
+        if not prompt_time:
+            prompt_time = time.perf_counter() - tic
+            tic = time.perf_counter()
+        ended = ended | (tokens == tokenizer.eos_token_id)
+        if ended.all():
+            break
+        output_toks.append(tokens)
+        if verbose:
+            print(".", end="", flush=True)
+    output_toks = mx.stack(output_toks, axis=-1)
+    token_count = output_toks.size
+    response = [
+        response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
+        for response in tokenizer.batch_decode(output_toks.tolist())
+    ]
+    if verbose:
+        gen_time = time.perf_counter() - tic
+        if token_count <= 0:
+            print("No tokens generated for this prompt")
+        else:
+            print()
+            for p, resp in zip(prompts, response):
+                print("=" * 10)
+                print("Prompt:", p)
+                print(resp)
+                print("=" * 10)
+            if prompt_time:
+                prompt_tps = input_ids.size / prompt_time
+                print(f"Prompt: {input_ids.size} tokens, {prompt_tps:.3f} tokens-per-sec")
+            gen_tps = token_count / gen_time
+            print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
+            peak_mem = mx.metal.get_peak_memory() / 2**30
+            print(f"Peak memory: {peak_mem:.3f} GB")
+
+    return response
 
 
 def load_config(model_path: Path) -> dict:
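For readers following the masking logic above: the tokenizer left-pads the batch, and the per-token padding mask is broadcast against a causal mask so that padded key positions are suppressed for every query. The following is a small standalone sketch, not part of the commit; the lower-triangular `causal` array here is a simplified stand-in for `create_causal_mask`.

# Standalone sketch of the padding + causal masking used in batch_generate.
# `causal` is a simplified stand-in for mlx_lm's create_causal_mask.
import mlx.core as mx

# attention_mask as returned by a left-padding tokenizer: 0 marks padding.
token_mask = mx.array(
    [
        [0, 0, 1, 1],  # short prompt, padded with two tokens on the left
        [1, 1, 1, 1],  # longer prompt, no padding
    ]
)

T = token_mask.shape[-1]
idx = mx.arange(T)
# Additive causal mask: 0 where the query position >= key position, -1e9 otherwise.
causal = mx.where(idx[:, None] >= idx[None, :], 0.0, -1e9)

# Broadcast the per-sequence padding mask over (batch, heads, queries, keys)
# and knock out padded key positions, mirroring the mx.where call above.
mask = mx.where(token_mask[:, None, None, :], causal, -1e9)
print(mask.shape)  # (2, 1, 4, 4)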
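A minimal usage sketch of the new API follows; it is not part of the diff. The checkpoint name is borrowed from the test suite below, and max_tokens is simply forwarded to generate_step through **kwargs.

# Hedged usage sketch for batch_generate; model path borrowed from the tests.
from mlx_lm.utils import batch_generate, load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

responses = batch_generate(
    model,
    tokenizer,
    prompts=[
        "hello",
        "this is a longer prompt to test out the padding and masking. hello",
    ],
    max_tokens=32,  # forwarded to generate_step via **kwargs
    verbose=True,   # prints progress dots plus prompt/generation tokens-per-sec
)

for response in responses:
    print(response)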
@@ -2,12 +2,11 @@
 
 import unittest
 
-from mlx_lm.sample_utils import make_logits_processors
-from mlx_lm.utils import generate, load
+from mlx_lm.sample_utils import make_logits_processors, make_sampler
+from mlx_lm.utils import generate, batch_generate, load
 
 
 class TestGenerate(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
         HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
@@ -51,6 +50,23 @@ class TestGenerate(unittest.TestCase):
         )
         self.assertEqual(all_toks.shape[-1], len(init_toks) + 5)
 
+    def test_batch_generate(self):
+        logit_bias = {0: 20.0, 1: -20.0}
+        texts = batch_generate(
+            self.model,
+            self.tokenizer,
+            [
+                "hello",
+                "this is a longer prompt to test out the padding and masking. hello",
+            ],
+            max_tokens=5,
+            prefill_step_size=4,
+            sampler=make_sampler(temp=1.0, min_p=0.1),
+            logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0),
+            verbose=False,
+        )
+        self.assertEqual(texts, ['!', '!'])
 
 
 if __name__ == "__main__":
     unittest.main()
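To exercise just the new case, something along these lines works; the module name test_generate is an assumption based on the TestGenerate class shown above, not something this diff states.

# Assumes the tests above live in a module named test_generate (an assumption,
# not shown in this diff). Equivalent to:
#   python -m unittest test_generate.TestGenerate.test_batch_generate
import unittest

from test_generate import TestGenerate

suite = unittest.TestSuite()
suite.addTest(TestGenerate("test_batch_generate"))
unittest.TextTestRunner(verbosity=2).run(suite)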