repetition_penalty and logits_bias just using logits_processors (#1004)
* refactor of repetition_penalty and logits_bias to use logits_processor
* nits

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent 418d9a5511
commit 0866e23a67
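
With this refactor, generate_step accepts a list of processor callables instead of a single one, and repetition penalty and logit bias are implemented internally as entries in that list. A minimal usage sketch of the new convention; the model repo and the cap_logits processor are hypothetical, and loading via mlx_lm.load is assumed:

    import mlx.core as mx
    from mlx_lm import load, generate

    def cap_logits(tokens, logits):
        # Every processor has the shape (tokens, logits) -> logits.
        return mx.minimum(logits, 10.0)

    model, tokenizer = load("mlx-community/SomeModel-4bit")  # illustrative repo
    text = generate(
        model,
        tokenizer,
        "hello",
        max_tokens=5,
        verbose=False,
        logits_processor=[cap_logits],  # a list now, not a single callable
    )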
@@ -101,7 +101,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
     return model_path


-def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
+def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
     """
     Apply repetition penalty to specific logits based on the given context.

@@ -109,19 +109,18 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f

     Args:
         logits (mx.array): The logits produced by the language model.
-        generated_tokens (any): A list of N previous tokens.
+        tokens (mx.array): A list of N previous tokens.
         penalty (float): The repetition penalty factor to be applied.

     Returns:
         logits (mx.array): Logits with repetition penalty applied to generated tokens.
     """
-    if len(generated_tokens) > 0:
-        indices = mx.array([token for token in generated_tokens])
-        selected_logits = logits[:, indices]
+    if len(tokens) > 0:
+        selected_logits = logits[:, tokens]
         selected_logits = mx.where(
             selected_logits < 0, selected_logits * penalty, selected_logits / penalty
         )
-        logits[:, indices] = selected_logits
+        logits[:, tokens] = selected_logits
     return logits

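The rewritten helper indexes logits directly with an mx.array of token ids. To make the mx.where scaling concrete, a small illustrative run (the values are made up): positive logits of previously seen tokens are divided by the penalty and negative ones multiplied by it, so both move toward lower probability.

    import mlx.core as mx

    logits = mx.array([[2.0, -1.0, 0.5]])
    tokens = mx.array([0, 1])  # ids generated so far
    penalty = 1.5

    selected = logits[:, tokens]
    selected = mx.where(selected < 0, selected * penalty, selected / penalty)
    logits[:, tokens] = selected
    # logits is now roughly [[1.33, -1.5, 0.5]]: both seen tokens lose probability.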
@@ -158,7 +157,7 @@ def generate_step(
     max_kv_size: Optional[int] = None,
     cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
     logit_bias: Optional[Dict[int, float]] = None,
-    logits_processor: Optional[Callable[[mx.array, mx.array], mx.array]] = None,
+    logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -182,8 +181,8 @@ def generate_step(
         max_kv_size (int, optional): Maximum size of the key-value cache. Old
           entries (except the first 4 tokens) will be overwritten.
         logit_bias (dictionary, optional): Additive logit bias.
-        logits_processor (Callable[[mx.array, mx.array], mx.array], optional):
-          A function that takes tokens and logits and returns the processed
+        logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
+          A list of functions that take tokens and logits and return the processed
           logits. Default: ``None``.

     Yields:
@@ -213,6 +212,27 @@ def generate_step(
             f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
         )

+    logits_processor = logits_processor or []
+
+    if repetition_penalty:
+
+        def repetition_penalty_processor(tokens, logits):
+            return apply_repetition_penalty(
+                logits, tokens[-repetition_context_size:], repetition_penalty
+            )
+
+        logits_processor.append(repetition_penalty_processor)
+
+    if logit_bias:
+        indices = mx.array(list(logit_bias.keys()))
+        values = mx.array(list(logit_bias.values()))
+
+        def logit_bias_processor(_, logits):
+            logits[:, indices] += values
+            return logits
+
+        logits_processor.append(logit_bias_processor)
+
     y = prompt
     tokens = None

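Because the penalty and bias are now just list entries, user-defined processors with the same (tokens, logits) -> logits shape compose with them. A hedged sketch of a hypothetical processor that bans a single token id, built the same way as logit_bias_processor above:

    banned_id = 13  # hypothetical token id

    def ban_token_processor(_, logits):
        # Push the banned token's logit to -inf so sampling never picks it.
        logits[:, banned_id] = float("-inf")
        return logits

    # Appending it after the built-ins makes it run last in _step's loop:
    # logits_processor.append(ban_token_processor)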
@@ -229,40 +249,18 @@ def generate_step(
             c.update_and_fetch(h[0], h[1])
         mx.eval([c.state for c in cache])

-    repetition_context = prompt.tolist()
-
-    if repetition_context_size:
-        repetition_context = repetition_context[-repetition_context_size:]
-
-    if logit_bias:
-        indices = mx.array(list(logit_bias.keys()))
-        values = mx.array(list(logit_bias.values()))
-
     def _step(y):
-        nonlocal repetition_context
         logits = model(y[None], cache=cache)
         logits = logits[:, -1, :]

         if logits_processor:
             nonlocal tokens
             tokens = mx.concat([tokens, y]) if tokens is not None else y
-            logits = logits_processor(tokens, logits)

-        if logit_bias:
-            logits[:, indices] += values
+            for processor in logits_processor:
+                logits = processor(tokens, logits)

-        if repetition_penalty:
-            logits = apply_repetition_penalty(
-                logits, repetition_context, repetition_penalty
-            )
-            y, logprobs = sample(logits)
-            repetition_context.append(y.item())
-        else:
-            y, logprobs = sample(logits)
-
-        if repetition_context_size:
-            if len(repetition_context) > repetition_context_size:
-                repetition_context = repetition_context[-repetition_context_size:]
+        y, logprobs = sample(logits)
         return y, logprobs.squeeze(0)

     while y.size > prefill_step_size:
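In the new _step, processors run in list order and each sees the previous one's output, so ordering matters when two processors touch the same logits. A standalone sketch of that composition (the helper name is hypothetical):

    def apply_processors(processors, tokens, logits):
        # Mirrors the loop in _step: later processors see earlier edits.
        for processor in processors:
            logits = processor(tokens, logits)
        return logits

The updated test below wraps its single test processor in a list to match the new signature.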
@@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase):
             "hello",
             max_tokens=5,
             verbose=False,
-            logits_processor=logits_processor,
+            logits_processor=[logits_processor],
         )
         self.assertEqual(len(all_toks), len(init_toks) + 5)