Allow prompt callback to generate_step (#1133)

* allow prompt callback and use in cache_prompt
* nit
* comments
* bump version
Committed by: Billel Mokeddem
Parent: a73de93247
Commit: e08c470d29
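The diff below adds a prompt_progress_callback keyword argument to generate_step: a callable invoked during prompt prefill with the number of prompt tokens processed so far and the total number of prompt tokens. A minimal usage sketch follows, assuming mlx_lm is installed; the model id and prompt text are illustrative only and are not part of this commit:

    # Report prefill progress while streaming tokens from generate_step.
    import mlx.core as mx
    from mlx_lm.utils import generate_step, load

    # Hypothetical model id, used only for illustration.
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
    prompt = mx.array(tokenizer.encode("Write a haiku about autumn."))

    def report_progress(processed: int, total: int) -> None:
        # Receives (prompt tokens processed so far, total prompt tokens).
        print(f"prefill: {processed}/{total}")

    # EOS handling is omitted here; stream_generate layers it on top of generate_step.
    for token, logprobs in generate_step(
        prompt,
        model,
        max_tokens=64,
        prompt_progress_callback=report_progress,
    ):
        print(tokenizer.decode([token]), end="", flush=True)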
@@ -183,6 +183,7 @@ def generate_step(
    prompt: mx.array,
    model: nn.Module,
    *,
    max_tokens: int = 256,
    sampler: Optional[Callable[mx.array, mx.array]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    max_kv_size: Optional[int] = None,
@@ -191,6 +192,7 @@ def generate_step(
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    prompt_progress_callback: Optional[Callable[int, int]] = None,
    temp: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
@@ -204,21 +206,25 @@ def generate_step(
    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        prefill_step_size (int): Step size for processing the prompt.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
          entries (except the first 4 tokens) will be overwritten.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
          provided, the cache will be updated in place.
        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
          generator. Default: ``256``.
        sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
          token from a vector of log probabilities. Default: ``None``.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
          A list of functions that take tokens and logits and return the processed
          logits. Default: ``None``.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
          entries (except the first 4 tokens) will be overwritten.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
          provided, the cache will be updated in place.
        prefill_step_size (int): Step size for processing the prompt.
        kv_bits (int, optional): Number of bits to use for KV cache quantization.
          None implies no cache quantization. Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
          when ``kv_bits`` is non-None. Default: ``0``.
        prompt_progress_callback (Callable[int, int]): A callback which takes the
          prompt tokens processed so far and the total number of prompt tokens.

    Yields:
        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
@@ -253,6 +259,7 @@ def generate_step(
    logits_processors = logits_processors or make_logits_processors(
        None, repetition_penalty, repetition_context_size or 20
    )
    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)

    def _step(y):
        with mx.stream(generation_stream):
@@ -275,9 +282,13 @@ def generate_step(
            return y, logprobs.squeeze(0)

    with mx.stream(generation_stream):
        total_prompt_tokens = y.size
        prompt_processed_tokens = 0
        while y.size > prefill_step_size:
            model(y[:prefill_step_size][None], cache=prompt_cache)
            mx.eval([c.state for c in prompt_cache])
            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
            prompt_processed_tokens += prefill_step_size
            y = y[prefill_step_size:]
            mx.metal.clear_cache()

@@ -286,20 +297,25 @@ def generate_step(
    mx.async_eval(y, logprobs)
    n = 0
    while True:
        next_y, next_logprobs = _step(y)
        mx.async_eval(next_y, next_logprobs)
        if n != max_tokens:
            next_y, next_logprobs = _step(y)
            mx.async_eval(next_y, next_logprobs)
        if n == 0:
            mx.eval(y)
            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
        if n == max_tokens:
            break
        yield y.item(), logprobs
        if n % 256 == 0:
            mx.metal.clear_cache()
        n += 1
        y, logprobs = next_y, next_logprobs
        n += 1


def stream_generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, mx.array, List[int]],
    max_tokens: int = 100,
    **kwargs,
) -> Generator[GenerationResponse, None, None]:
    """
@@ -309,7 +325,6 @@ def stream_generate(
        model (nn.Module): The model to use for generation.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
        max_tokens (int): The maximum number of tokens. Default: ``100``.
        kwargs: The remaining options get passed to :func:`generate_step`.
          See :func:`generate_step` for more details.

@@ -330,10 +345,7 @@ def stream_generate(
    with wired_limit(model, [generation_stream]):
        detokenizer.reset()
        tic = time.perf_counter()
        for n, (token, logprobs) in zip(
            range(max_tokens),
            generate_step(prompt, model, **kwargs),
        ):
        for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
            if n == 0:
                prompt_time = time.perf_counter() - tic
                prompt_tps = prompt.size / prompt_time
@@ -343,9 +355,6 @@ def stream_generate(

            detokenizer.add_token(token)

            if n == (max_tokens - 1):
                break

            yield GenerationResponse(
                text=detokenizer.last_segment,
                token=token,
@@ -385,7 +394,6 @@ def generate(
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        max_tokens (int): The maximum number of tokens. Default: ``100``.
        verbose (bool): If ``True``, print tokens and timing information.
          Default: ``False``.
        kwargs: The remaining options get passed to :func:`stream_generate`.
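Because stream_generate forwards its remaining keyword arguments to generate_step, the new callback can also be supplied through the higher-level streaming API. A sketch under the same assumptions as the example above (hypothetical model id):

    from mlx_lm.utils import load, stream_generate

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

    # prompt_progress_callback is forwarded to generate_step via **kwargs.
    for response in stream_generate(
        model,
        tokenizer,
        "Write a haiku about autumn.",
        max_tokens=64,
        prompt_progress_callback=lambda done, total: print(f"prefill: {done}/{total}"),
    ):
        print(response.text, end="", flush=True)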