add back logit bias + test

This commit is contained in:
Awni Hannun
2024-09-28 09:55:43 -07:00
parent c8216caa61
commit 83aaf0c98b
2 changed files with 21 additions and 0 deletions

View File

@@ -157,6 +157,7 @@ def generate_step(
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[Callable[[mx.array, mx.array], mx.array]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
@@ -180,6 +181,7 @@ def generate_step(
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
logit_bias (dictionary, optional): Additive logit bias.
logits_processor (Callable[[mx.array, mx.array], mx.array], optional):
A function that takes tokens and logits and returns the processed
logits. Default: ``None``.
@@ -232,6 +234,10 @@ def generate_step(
if repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
def _step(y):
nonlocal repetition_context
logits = model(y[None], cache=cache)
@@ -242,6 +248,9 @@ def generate_step(
tokens = mx.concat([tokens, y]) if tokens is not None else y
logits = logits_processor(tokens, logits)
if logit_bias:
logits[:, indices] += values
if repetition_penalty:
logits = apply_repetition_penalty(
logits, repetition_context, repetition_penalty

View File

@@ -18,6 +18,18 @@ class TestGenerate(unittest.TestCase):
self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
)
def test_generate_with_logit_bias(self):
logit_bias = {0: 2000.0, 1: -20.0}
text = generate(
self.model,
self.tokenizer,
"hello",
max_tokens=5,
verbose=False,
logit_bias=logit_bias,
)
self.assertEqual(text, "!!!!!")
def test_generate_with_processor(self):
init_toks = self.tokenizer.encode("hello")