Refactor of repetition_penalty and logit_bias to use logits_processor

This commit is contained in:
Nathan Ranchin
2024-09-29 18:01:37 +02:00
parent 7ec2021bb9
commit 39e5152ed8
2 changed files with 22 additions and 31 deletions

View File

@@ -158,7 +158,7 @@ def generate_step(
max_kv_size: Optional[int] = None,
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[Callable[[mx.array, mx.array], mx.array]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = [],
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -182,9 +182,9 @@ def generate_step(
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
logit_bias (dictionary, optional): Additive logit bias.
logits_processor (Callable[[mx.array, mx.array], mx.array], optional):
A function that takes tokens and logits and returns the processed
logits. Default: ``None``.
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``[]``.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@@ -212,6 +212,19 @@ def generate_step(
raise ValueError(
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
)
if repetition_penalty:
def repetition_penalty_processor(tokens: mx.array, logits: mx.array) -> mx.array:
return apply_repetition_penalty(logits, tokens[-repetition_context_size:], repetition_penalty)
logits_processor.append(repetition_penalty_processor)
if logit_bias:
def logit_bias_processor(_: mx.array, logits: mx.array) -> mx.array:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
logits[:, indices] += values
return logits
logits_processor.append(logit_bias_processor)
y = prompt
tokens = None
@@ -229,40 +242,18 @@ def generate_step(
c.update_and_fetch(h[0], h[1])
mx.eval([c.state for c in cache])
repetition_context = prompt.tolist()
if repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
def _step(y):
nonlocal repetition_context
logits = model(y[None], cache=cache)
logits = logits[:, -1, :]
if logits_processor:
nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y
logits = logits_processor(tokens, logits)
for processor in logits_processor:
logits = processor(tokens, logits)
if logit_bias:
logits[:, indices] += values
if repetition_penalty:
logits = apply_repetition_penalty(
logits, repetition_context, repetition_penalty
)
y, logprobs = sample(logits)
repetition_context.append(y.item())
else:
y, logprobs = sample(logits)
if repetition_context_size:
if len(repetition_context) > repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]
y, logprobs = sample(logits)
return y, logprobs.squeeze(0)
while y.size > prefill_step_size:

View File

@@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase):
"hello",
max_tokens=5,
verbose=False,
logits_processor=logits_processor,
logits_processor=[logits_processor],
)
self.assertEqual(len(all_toks), len(init_toks) + 5)