Add logits_processor option to generate_step function (#983)

* Add logits_processor option for the generation as in huggingface transformers library

* concatenation correction

* Rename the tokens variable for clarity

* remove the logit_bias argument from generate_step method

* fix the variable name

* nits + test

* test

* add back logit bias + test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
nathan
2024-09-28 19:08:49 +02:00
committed by GitHub
parent d812516d3d
commit ace2bb5890
2 changed files with 74 additions and 6 deletions

View File

@@ -154,10 +154,11 @@ def generate_step(
top_p: float = 1.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
logit_bias: Optional[Dict[int, float]] = None,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[Callable[[mx.array, mx.array], mx.array]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -177,10 +178,13 @@ def generate_step(
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
logit_bias (dictionary, optional): Additive logit bias.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
logit_bias (dictionary, optional): Additive logit bias.
logits_processor (Callable[[mx.array, mx.array], mx.array], optional):
A function that takes tokens and logits and returns the processed
logits. Default: ``None``.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@@ -188,10 +192,6 @@ def generate_step(
"""
def sample(logits: mx.array) -> Tuple[mx.array, float]:
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
logits[:, indices] += values
logprobs = logits - mx.logsumexp(logits)
if temp == 0:
@@ -214,6 +214,7 @@ def generate_step(
)
y = prompt
tokens = None
# Create the KV cache for generation
cache = make_kv_caches(model, max_kv_size)
@@ -233,11 +234,23 @@ def generate_step(
if repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
def _step(y):
nonlocal repetition_context
logits = model(y[None], cache=cache)
logits = logits[:, -1, :]
if logits_processor:
nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y
logits = logits_processor(tokens, logits)
if logit_bias:
logits[:, indices] += values
if repetition_penalty:
logits = apply_repetition_penalty(
logits, repetition_context, repetition_penalty