mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 12:26:07 +08:00
implement batch_generate
This commit is contained in:
parent
cded14988c
commit
465eb79fff
@ -21,6 +21,7 @@ from transformers import PreTrainedTokenizer
|
|||||||
|
|
||||||
# Local imports
|
# Local imports
|
||||||
from .models import cache
|
from .models import cache
|
||||||
|
from .models.base import create_causal_mask
|
||||||
from .sample_utils import make_logits_processors, make_sampler
|
from .sample_utils import make_logits_processors, make_sampler
|
||||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||||
from .tuner.utils import dequantize as dequantize_model
|
from .tuner.utils import dequantize as dequantize_model
|
||||||
@ -355,6 +356,7 @@ def generate_step(
|
|||||||
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
||||||
if n == max_tokens:
|
if n == max_tokens:
|
||||||
break
|
break
|
||||||
|
mx.eval(y)
|
||||||
yield y, logprobs
|
yield y, logprobs
|
||||||
if n % 256 == 0:
|
if n % 256 == 0:
|
||||||
mx.metal.clear_cache()
|
mx.metal.clear_cache()
|
||||||
@ -488,8 +490,85 @@ def generate(
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def batch_generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompts: list[str],
    verbose: bool = False,
    **kwargs,
) -> list[str]:
    """
    Generate complete responses from the model for a list of prompts.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompts (list[str]): The string prompts.
        verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
        kwargs: The remaining options get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.

    Returns:
        list[str]: One decoded response per prompt, with any trailing
        eos/pad text stripped. Empty strings if no tokens were generated.

    Raises:
        ValueError: If ``prompt_cache`` is passed; batch generation does
        not support it yet.
    """
    if "prompt_cache" in kwargs:
        # TODO: Handle `prompt_cache` and `prompt` both left-padded, so that
        # we have <pad>text<pad>text. Should involve taking `prompt_cache_lens`
        # to extend `mask` below, and handling position_ids (see TODO below)
        raise ValueError("Batch generation does not support prompt_cache yet.")
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    # Left-pad so every prompt ends at the same position and the batch can
    # step in lockstep.
    # TODO: left-shift position_ids for absolute/rotary positional encodings
    # Example: https://github.com/huggingface/transformers/issues/26072#issuecomment-2101209470
    tokenizer._tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer._tokenizer.pad_token = tokenizer.eos_token
        tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
    res = tokenizer._tokenizer(prompts, padding=True)
    input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
    # Merge the causal mask with the padding mask: padded positions get a
    # large negative additive bias so attention never attends to them.
    causal_mask = create_causal_mask(token_mask.shape[-1])
    mask = mx.where(token_mask[:, None, None, :], causal_mask, -1e9)

    output_toks = []
    prompt_time = None
    # Per-sequence flag, set once a sequence has emitted eos.
    ended = mx.zeros(len(prompts), dtype=mx.bool_)
    tic = time.perf_counter()
    # TODO: non-generator version of `generate_step` so that we can
    # add or remove prompts from the batch as they start/finish
    for tokens, _ in generate_step(input_ids, model, mask=mask, **kwargs):
        # FIX: use an explicit `is None` sentinel check; the previous falsy
        # check (`if not prompt_time`) would re-trigger on a 0.0 duration.
        if prompt_time is None:
            # First yielded step: everything up to here was prompt prefill.
            prompt_time = time.perf_counter() - tic
            tic = time.perf_counter()
        ended = ended | (tokens == tokenizer.eos_token_id)
        if ended.all():
            break
        output_toks.append(tokens)
        if verbose:
            print(".", end="", flush=True)

    # FIX: guard the empty case — `mx.stack([])` raises, e.g. when every
    # sequence emits eos on the very first step (the verbose path below
    # already anticipates `token_count <= 0`, so this case is reachable).
    if output_toks:
        output_toks = mx.stack(output_toks, axis=-1)
        token_count = output_toks.size
        response = [
            response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
            for response in tokenizer.batch_decode(output_toks.tolist())
        ]
    else:
        token_count = 0
        response = ["" for _ in prompts]

    if verbose:
        gen_time = time.perf_counter() - tic
        if token_count <= 0:
            print("No tokens generated for this prompt")
        else:
            print()
            for p, resp in zip(prompts, response):
                print("=" * 10)
                print("Prompt:", p)
                print(resp)
                print("=" * 10)
        if prompt_time:
            prompt_tps = input_ids.size / prompt_time
            print(f"Prompt: {input_ids.size} tokens, {prompt_tps:.3f} tokens-per-sec")
        gen_tps = token_count / gen_time
        print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
        peak_mem = mx.metal.get_peak_memory() / 2**30
        print(f"Peak memory: {peak_mem:.3f} GB")

    return response
||||||
|
|
||||||
|
|
||||||
def load_config(model_path: Path) -> dict:
|
def load_config(model_path: Path) -> dict:
|
||||||
|
@ -2,12 +2,11 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from mlx_lm.sample_utils import make_logits_processors
|
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||||
from mlx_lm.utils import generate, load
|
from mlx_lm.utils import generate, batch_generate, load
|
||||||
|
|
||||||
|
|
||||||
class TestGenerate(unittest.TestCase):
|
class TestGenerate(unittest.TestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||||
@ -51,6 +50,23 @@ class TestGenerate(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(all_toks.shape[-1], len(init_toks) + 5)
|
self.assertEqual(all_toks.shape[-1], len(init_toks) + 5)
|
||||||
|
|
||||||
|
def test_batch_generate(self):
|
||||||
|
logit_bias = {0: 20.0, 1: -20.0}
|
||||||
|
texts = batch_generate(
|
||||||
|
self.model,
|
||||||
|
self.tokenizer,
|
||||||
|
[
|
||||||
|
"hello",
|
||||||
|
"this is a longer prompt to test out the padding and masking. hello",
|
||||||
|
],
|
||||||
|
max_tokens=5,
|
||||||
|
prefill_step_size=4,
|
||||||
|
sampler=make_sampler(temp=1.0, min_p=0.1),
|
||||||
|
logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0),
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
self.assertEqual(texts, ['!', '!'])
|
||||||
|
|
||||||
|
|
||||||
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user