implement batch_generate

llvvuu
2024-12-27 13:47:09 -08:00
parent cded14988c
commit 465eb79fff
2 changed files with 100 additions and 5 deletions


@@ -21,6 +21,7 @@ from transformers import PreTrainedTokenizer
# Local imports
from .models import cache
from .models.base import create_causal_mask
from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
@@ -355,6 +356,7 @@ def generate_step(
            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
        if n == max_tokens:
            break
        mx.eval(y)
        yield y, logprobs
        if n % 256 == 0:
            mx.metal.clear_cache()
@@ -488,8 +490,85 @@ def generate(
    return text
def batch_generate():
    pass
def batch_generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompts: list[str],
    verbose: bool = False,
    **kwargs,
) -> list[str]:
    """
    Generate complete responses from the model for a list of prompts.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompts (List[str]): The string prompts.
        verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.
    """
    if "prompt_cache" in kwargs:
        # TODO: Handle `prompt_cache` and `prompt` both left-padded, so that
        # we have <pad>text<pad>text. Should involve taking `prompt_cache_lens`
        # to extend `mask` below, and handling position_ids (see TODO below)
        raise ValueError("Batch generation does not support prompt_cache yet.")
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)
    # TODO: left-shift position_ids for absolute/rotary positional encodings
    # Example: https://github.com/huggingface/transformers/issues/26072#issuecomment-2101209470
    tokenizer._tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer._tokenizer.pad_token = tokenizer.eos_token
        tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
    res = tokenizer._tokenizer(prompts, padding=True)
    input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
    # Combine the causal mask with the padding mask so that no sequence in the
    # batch attends to its left-padding positions.
    causal_mask = create_causal_mask(token_mask.shape[-1])
    mask = mx.where(token_mask[:, None, None, :], causal_mask, -1e9)

    output_toks = []
    prompt_time = None
    ended = mx.zeros(len(prompts), dtype=mx.bool_)
    tic = time.perf_counter()
    # TODO: non-generator version of `generate_step` so that we can
    # add or remove prompts from the batch as they start/finish
    for tokens, _ in generate_step(input_ids, model, mask=mask, **kwargs):
        if not prompt_time:
            prompt_time = time.perf_counter() - tic
            tic = time.perf_counter()
        ended = ended | (tokens == tokenizer.eos_token_id)
        if ended.all():
            break
        output_toks.append(tokens)
        if verbose:
            print(".", end="", flush=True)
    output_toks = mx.stack(output_toks, axis=-1)
    token_count = output_toks.size
    # Decode each row and trim everything after the first EOS or pad token.
    response = [
        response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
        for response in tokenizer.batch_decode(output_toks.tolist())
    ]
    if verbose:
        gen_time = time.perf_counter() - tic
        if token_count <= 0:
            print("No tokens generated for this prompt")
        else:
            print()
            for p, resp in zip(prompts, response):
                print("=" * 10)
                print("Prompt:", p)
                print(resp)
            print("=" * 10)
            if prompt_time:
                prompt_tps = input_ids.size / prompt_time
                print(f"Prompt: {input_ids.size} tokens, {prompt_tps:.3f} tokens-per-sec")
            gen_tps = token_count / gen_time
            print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
            peak_mem = mx.metal.get_peak_memory() / 2**30
            print(f"Peak memory: {peak_mem:.3f} GB")

    return response
def load_config(model_path: Path) -> dict:
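
A minimal usage sketch of the new function, assuming batch_generate is imported from mlx_lm.utils and that mlx_lm.load() returns a (model, tokenizer) pair; the model path, prompts, and max_tokens value below are illustrative placeholders, not part of the change itself.

# Usage sketch: batched generation over left-padded prompts.
from mlx_lm import load
from mlx_lm.utils import batch_generate

# Illustrative model path; any MLX-format checkpoint should work here.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

prompts = [
    "Write a haiku about the ocean.",
    "Explain left padding in one sentence.",
]

# Remaining keyword arguments (e.g. max_tokens) are forwarded to generate_step.
responses = batch_generate(model, tokenizer, prompts, verbose=True, max_tokens=128)
for prompt, response in zip(prompts, responses):
    print(prompt, "->", response)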