mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add Repetitive penalty to LLM inference - mlx-lm (#399)
* feat: add repetition penalty
* fix: generate function argument fix
* typo fixes
* update repetitive penalty
* update generate_step and generate
* resolve conflicts in generate
* merge latest pull from origin master
* update generate
* update generate and generate_step
* update repetition list - rename variable
* refactor token count
* update generate step and generate
* move repetition_context in generate_step
* update generate step
* update generate_step
This commit is contained in:
parent 0ba466369f
commit 21e19b5b5a
@@ -5,7 +5,7 @@ import json
 import logging
 import time
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -80,10 +80,36 @@ def get_model_path(path_or_hf_repo: str) -> Path:
     return model_path
 
 
+def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
+    """
+    Apply repetition penalty to specific logits based on the given context.
+
+    Paper: https://arxiv.org/abs/1909.05858
+
+    Args:
+        logits (mx.array): The logits produced by the language model.
+        generated_tokens (any): A list of N previous tokens.
+        penalty (float): The repetition penalty factor to be applied.
+
+    Returns:
+        logits (mx.array): Logits with repetition penalty applied to generated tokens.
+    """
+    if len(generated_tokens) > 0:
+        indices = mx.array([token for token in generated_tokens])
+        selected_logits = logits[:, indices]
+        selected_logits = mx.where(
+            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
+        )
+        logits[:, indices] = selected_logits
+    return logits
+
+
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
-    temp: float = 0.0,
+    temp: 0.0,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = 20,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
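For readers skimming the diff, here is a minimal standalone sketch of the penalty transform added above. The helper body mirrors the function in this commit; the logits values and token ids are made up for illustration. Following the CTRL paper, logits of previously generated tokens are divided by the penalty when positive and multiplied by it when negative, so a penalty above 1.0 makes repeats less likely.

    import mlx.core as mx

    def apply_repetition_penalty(logits, generated_tokens, penalty):
        # Same logic as the helper added in this commit.
        if len(generated_tokens) > 0:
            indices = mx.array([token for token in generated_tokens])
            selected_logits = logits[:, indices]
            selected_logits = mx.where(
                selected_logits < 0, selected_logits * penalty, selected_logits / penalty
            )
            logits[:, indices] = selected_logits
        return logits

    logits = mx.array([[2.0, -1.0, 0.5, 3.0]])  # one row over a toy 4-token vocabulary
    penalized = apply_repetition_penalty(logits, generated_tokens=[0, 1], penalty=1.5)
    print(penalized)  # token 0: 2.0 -> ~1.33, token 1: -1.0 -> -1.5; tokens 2 and 3 unchanged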
@@ -92,6 +118,9 @@ def generate_step(
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
         temp (float): The temperature for sampling, if 0 the argmax is used.
+        repetition_penalty (float, optional): The penalty factor for repeating tokens.
+        repetition_context_size (int, optional): The number of tokens to consider
+          for repetition penalty (default 20).
 
     Yields:
         Generator[Tuple[mx.array, mx.array]]: A generator producing
         one token and probability per call.
@@ -108,12 +137,37 @@ def generate_step(
         prob = softmax_logits[0, token]
         return token, prob
 
+    if repetition_penalty and (
+        repetition_penalty < 0 or not isinstance(repetition_penalty, float)
+    ):
+        raise ValueError(
+            f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
+        )
+
     y = prompt
     cache = None
+
+    repetition_context = prompt.tolist()
+
+    if repetition_context_size:
+        repetition_context = repetition_context[-repetition_context_size:]
+
     while True:
         logits, cache = model(y[None], cache=cache)
         logits = logits[:, -1, :]
-        y, prob = sample(logits)
+
+        if repetition_penalty:
+            logits = apply_repetition_penalty(
+                logits, repetition_context, repetition_penalty
+            )
+            y, prob = sample(logits)
+            repetition_context.append(y.item())
+        else:
+            y, prob = sample(logits)
+
+        if repetition_context_size:
+            if len(repetition_context) > repetition_context_size:
+                repetition_context = repetition_context[-repetition_context_size:]
         yield y, prob
 
 
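The repetition_context bookkeeping above is easiest to see in isolation. A minimal sketch with made-up token ids and a stand-in for the sampler (no model involved): the context is seeded from the tail of the prompt, grows by one sampled token per step, and is trimmed back to repetition_context_size.

    repetition_context_size = 4
    prompt_ids = [5, 6, 7, 8, 9, 10]                  # hypothetical prompt token ids

    repetition_context = prompt_ids[-repetition_context_size:]   # [7, 8, 9, 10]
    for sampled in [11, 12, 13]:                      # stand-ins for tokens from sample()
        repetition_context.append(sampled)
        if len(repetition_context) > repetition_context_size:
            repetition_context = repetition_context[-repetition_context_size:]
        print(repetition_context)
    # [8, 9, 10, 11] -> [9, 10, 11, 12] -> [10, 11, 12, 13]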
@@ -125,6 +179,8 @@ def generate(
     max_tokens: int = 100,
     verbose: bool = False,
     formatter: Callable = None,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = None,
 ) -> str:
     """
     Generate text from the model.
@@ -139,20 +195,31 @@ def generate(
             (default ``False``).
         formatter (Optional[Callable]): A function which takes a token and a
             probability and displays it.
+        repetition_penalty (float, optional): The penalty factor for repeating tokens.
+        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
     """
 
     if verbose:
         print("=" * 10)
         print("Prompt:", prompt)
 
-    prompt = mx.array(tokenizer.encode(prompt))
+    prompt_tokens = mx.array(tokenizer.encode(prompt))
 
     tic = time.perf_counter()
     tokens = []
     skip = 0
     REPLACEMENT_CHAR = "\ufffd"
 
-    for (token, prob), n in zip(generate_step(prompt, model, temp), range(max_tokens)):
+    for (token, prob), n in zip(
+        generate_step(
+            prompt_tokens,
+            model,
+            temp,
+            repetition_penalty,
+            repetition_context_size,
+        ),
+        range(max_tokens),
+    ):
         if token == tokenizer.eos_token_id:
             break
         if n == 0:
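After this change, the new arguments can be passed straight through the high-level API. A usage sketch (the model repo, prompt, and penalty values are illustrative, not taken from the diff):

    from mlx_lm import load, generate

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")  # any mlx-lm model
    text = generate(
        model,
        tokenizer,
        prompt="Write a short note on repetition penalties.",
        max_tokens=100,
        repetition_penalty=1.1,        # > 1.0 discourages repeating recent tokens
        repetition_context_size=20,    # only the last 20 tokens are penalized
        verbose=True,
    )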
@@ -179,7 +246,7 @@ def generate(
         if token_count == 0:
             print("No tokens generated for this prompt")
             return
-        prompt_tps = prompt.size / prompt_time
+        prompt_tps = prompt_tokens.size / prompt_time
         gen_tps = (token_count - 1) / gen_time
         print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
         print(f"Generation: {gen_tps:.3f} tokens-per-sec")
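A quick worked example of the throughput report in the hunk above, with made-up timings (and assuming, as in the surrounding code, that the first token's latency is attributed to prompt processing, which is why token_count - 1 is used):

    prompt_size = 64      # prompt_tokens.size
    prompt_time = 0.20    # seconds until the first token
    token_count = 101     # tokens generated in total
    gen_time = 2.50       # seconds spent after the first token

    prompt_tps = prompt_size / prompt_time      # 320.0
    gen_tps = (token_count - 1) / gen_time      # 40.0
    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
    print(f"Generation: {gen_tps:.3f} tokens-per-sec")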