Allow prompt callback to generate_step (#1133)

* allow prompt callback and use in cache_prompt

* nit

* comments

* bump version
This commit is contained in:
Awni Hannun
2024-12-03 16:17:14 -08:00
committed by Billel Mokeddem
parent a73de93247
commit e08c470d29
5 changed files with 48 additions and 48 deletions

View File

@@ -183,6 +183,7 @@ def generate_step(
prompt: mx.array,
model: nn.Module,
*,
max_tokens: int = 256,
sampler: Optional[Callable[mx.array, mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
max_kv_size: Optional[int] = None,
@@ -191,6 +192,7 @@ def generate_step(
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[int, int]] = None,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
@@ -204,21 +206,25 @@ def generate_step(
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
prefill_step_size (int): Step size for processing the prompt.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache.
when ``kv_bits`` is non-None. Default: ``0``.
prompt_prorgress_callback (Callable[int, int]): A call-back which takes the
prompt tokens processed so far and the total number of prompt tokens.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
@@ -253,6 +259,7 @@ def generate_step(
logits_processors = logits_processors or make_logits_processors(
None, repetition_penalty, repetition_context_size or 20
)
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
def _step(y):
with mx.stream(generation_stream):
@@ -275,9 +282,13 @@ def generate_step(
return y, logprobs.squeeze(0)
with mx.stream(generation_stream):
total_prompt_tokens = y.size
prompt_processed_tokens = 0
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
prompt_processed_tokens += prefill_step_size
y = y[prefill_step_size:]
mx.metal.clear_cache()
@@ -286,20 +297,25 @@ def generate_step(
mx.async_eval(y, logprobs)
n = 0
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)
if n != max_tokens:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)
if n == 0:
mx.eval(y)
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
if n == max_tokens:
break
yield y.item(), logprobs
if n % 256 == 0:
mx.metal.clear_cache()
n += 1
y, logprobs = next_y, next_logprobs
n += 1
def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, mx.array, List[int]],
max_tokens: int = 100,
**kwargs,
) -> Generator[GenerationResponse, None, None]:
"""
@@ -309,7 +325,6 @@ def stream_generate(
model (nn.Module): The model to use for generation.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
max_tokens (int): The maximum number of tokens. Default: ``100``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
@@ -330,10 +345,7 @@ def stream_generate(
with wired_limit(model, [generation_stream]):
detokenizer.reset()
tic = time.perf_counter()
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt, model, **kwargs),
):
for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time
@@ -343,9 +355,6 @@ def stream_generate(
detokenizer.add_token(token)
if n == (max_tokens - 1):
break
yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
@@ -385,7 +394,6 @@ def generate(
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
kwargs: The remaining options get passed to :func:`stream_generate`.