mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-04 15:54:34 +08:00
implement batch_generate
This commit is contained in:
@@ -21,6 +21,7 @@ from transformers import PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
from .models import cache
|
||||
from .models.base import create_causal_mask
|
||||
from .sample_utils import make_logits_processors, make_sampler
|
||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||
from .tuner.utils import dequantize as dequantize_model
|
||||
@@ -355,6 +356,7 @@ def generate_step(
|
||||
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
||||
if n == max_tokens:
|
||||
break
|
||||
mx.eval(y)
|
||||
yield y, logprobs
|
||||
if n % 256 == 0:
|
||||
mx.metal.clear_cache()
|
||||
@@ -488,8 +490,85 @@ def generate(
|
||||
return text
|
||||
|
||||
|
||||
def batch_generate():
|
||||
pass
|
||||
def batch_generate(
|
||||
model: nn.Module,
|
||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||
prompts: list[str],
|
||||
verbose: bool = False,
|
||||
**kwargs,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Generate complete responses from the model for a list of prompts.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The language model.
|
||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||
prompts (List[str]): The string prompts.
|
||||
verbose (bool): If ``True``, print tokens and timing information.
|
||||
Default: ``False``.
|
||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||
See :func:`generate_step` for more details.
|
||||
"""
|
||||
if 'prompt_cache' in kwargs:
|
||||
# TODO: Handle `prompt_cache` and `prompt` both left-padded, so that
|
||||
# we have <pad>text<pad>text. Should involve taking `prompt_cache_lens`
|
||||
# to extend `mask` below, and handling position_ids (see TODO below)
|
||||
raise ValueError("Batch generation does not support prompt_cache yet.")
|
||||
if not isinstance(tokenizer, TokenizerWrapper):
|
||||
tokenizer = TokenizerWrapper(tokenizer)
|
||||
# TODO: left-shift position_ids for absolute/rotary positional encodings
|
||||
# Example: https://github.com/huggingface/transformers/issues/26072#issuecomment-2101209470
|
||||
tokenizer._tokenizer.padding_side = "left"
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer._tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
res = tokenizer._tokenizer(prompts, padding=True)
|
||||
input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
|
||||
causal_mask = create_causal_mask(token_mask.shape[-1])
|
||||
mask = mx.where(token_mask[:, None, None, :], causal_mask, -1e9)
|
||||
|
||||
output_toks = []
|
||||
prompt_time = None
|
||||
ended = mx.zeros(len(prompts), dtype=mx.bool_)
|
||||
tic = time.perf_counter()
|
||||
# TODO: non-generator version of `generate_step` so that we can
|
||||
# add or remove prompts from the batch as they start/finish
|
||||
for tokens, _ in generate_step(input_ids, model, mask=mask, **kwargs):
|
||||
if not prompt_time:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
tic = time.perf_counter()
|
||||
ended = ended | (tokens == tokenizer.eos_token_id)
|
||||
if ended.all():
|
||||
break
|
||||
output_toks.append(tokens)
|
||||
if verbose:
|
||||
print(".", end="", flush=True)
|
||||
output_toks = mx.stack(output_toks, axis=-1)
|
||||
token_count = output_toks.size
|
||||
response = [
|
||||
response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
|
||||
for response in tokenizer.batch_decode(output_toks.tolist())
|
||||
]
|
||||
if verbose:
|
||||
gen_time = time.perf_counter() - tic
|
||||
if token_count <= 0:
|
||||
print("No tokens generated for this prompt")
|
||||
else:
|
||||
print()
|
||||
for p, resp in zip(prompts, response):
|
||||
print("=" * 10)
|
||||
print("Prompt:", p)
|
||||
print(resp)
|
||||
print("=" * 10)
|
||||
if prompt_time:
|
||||
prompt_tps = input_ids.size / prompt_time
|
||||
print(f"Prompt: {input_ids.size} tokens, {prompt_tps:.3f} tokens-per-sec")
|
||||
gen_tps = token_count / gen_time
|
||||
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
||||
print(f"Peak memory: {peak_mem:.3f} GB")
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def load_config(model_path: Path) -> dict:
|
||||
|
Reference in New Issue
Block a user