Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Add Repetitive penalty to LLM inference - mlx-lm (#399)
* feat: add repetition penalty
* fix: generate function argument fix
* typo fixes
* update repetitive penalty
* update generate_step and generate
* resolve conflicts in generate
* merge latest pull origin master
* update generate
* update generate and generate_step
* update repetition list - rename variable
* refactor token count
* update generate step and generate
* move repetition_context in generate_step
* update generate step
* update generate_step
This commit is contained in:
parent 0ba466369f
commit 21e19b5b5a
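At a high level, the change threads two new keyword arguments, repetition_penalty and repetition_context_size, from generate() down into generate_step(). A minimal usage sketch, assuming mlx_lm's load/generate entry points; the model repo and prompt below are purely illustrative, not part of this commit:

    from mlx_lm import load, generate

    # Illustrative model repo; any mlx_lm-compatible model works here.
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")

    text = generate(
        model,
        tokenizer,
        prompt="Write a haiku about the ocean.",
        max_tokens=100,
        repetition_penalty=1.3,       # must be a non-negative float; >1 discourages repeats
        repetition_context_size=20,   # only the most recent 20 tokens are penalized
        verbose=True,
    )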
@@ -5,7 +5,7 @@ import json
 import logging
 import time
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -80,10 +80,36 @@ def get_model_path(path_or_hf_repo: str) -> Path:
     return model_path


+def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
+    """
+    Apply repetition penalty to specific logits based on the given context.
+
+    Paper: https://arxiv.org/abs/1909.05858
+
+    Args:
+        logits (mx.array): The logits produced by the language model.
+        generated_tokens (any): A list of N previous tokens.
+        penalty (float): The repetition penalty factor to be applied.
+
+    Returns:
+        logits (mx.array): Logits with repetition penalty applied to generated tokens.
+    """
+    if len(generated_tokens) > 0:
+        indices = mx.array([token for token in generated_tokens])
+        selected_logits = logits[:, indices]
+        selected_logits = mx.where(
+            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
+        )
+        logits[:, indices] = selected_logits
+    return logits
+
+
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
-    temp: float = 0.0,
+    temp: 0.0,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = 20,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
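The new helper penalizes any logit whose token already appears in the recent context: positive logits are divided by the penalty and negative ones multiplied by it, following the CTRL paper linked above. A small self-contained sketch of that rule on made-up numbers (values chosen purely for illustration):

    import mlx.core as mx

    # Toy vocabulary of 4 tokens; tokens 0 and 2 were generated recently (illustrative values).
    logits = mx.array([[2.0, 1.0, -1.0, 0.5]])
    generated_tokens = [0, 2]
    penalty = 1.5

    # Same rule as apply_repetition_penalty above:
    #   token 0:  2.0 ->  2.0 / 1.5 ≈ 1.33  (positive logits are divided by the penalty)
    #   token 2: -1.0 -> -1.0 * 1.5 = -1.5  (negative logits are multiplied by it)
    #   tokens 1 and 3 are left untouched.
    indices = mx.array(generated_tokens)
    selected = logits[:, indices]
    logits[:, indices] = mx.where(selected < 0, selected * penalty, selected / penalty)
    print(logits)  # expected: [[1.33..., 1.0, -1.5, 0.5]]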
@@ -92,6 +118,9 @@ def generate_step(
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
         temp (float): The temperature for sampling, if 0 the argmax is used.
+        repetition_penalty (float, optional): The penalty factor for repeating tokens.
+        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
+
     Yields:
         Generator[Tuple[mx.array, mx.array]]: A generator producing
         one token and probability per call.
@@ -108,12 +137,37 @@ def generate_step(
         prob = softmax_logits[0, token]
         return token, prob

+    if repetition_penalty and (
+        repetition_penalty < 0 or not isinstance(repetition_penalty, float)
+    ):
+        raise ValueError(
+            f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
+        )
+
     y = prompt
     cache = None
+
+    repetition_context = prompt.tolist()
+
+    if repetition_context_size:
+        repetition_context = repetition_context[-repetition_context_size:]
+
     while True:
         logits, cache = model(y[None], cache=cache)
         logits = logits[:, -1, :]
-        y, prob = sample(logits)
+
+        if repetition_penalty:
+            logits = apply_repetition_penalty(
+                logits, repetition_context, repetition_penalty
+            )
+            y, prob = sample(logits)
+            repetition_context.append(y.item())
+        else:
+            y, prob = sample(logits)
+
+        if repetition_context_size:
+            if len(repetition_context) > repetition_context_size:
+                repetition_context = repetition_context[-repetition_context_size:]
         yield y, prob


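For callers that drive the token generator directly, a rough sketch of using generate_step with the new arguments, mirroring the zip(..., range(max_tokens)) pattern that generate() adopts below. The import paths (mlx_lm.load, mlx_lm.utils.generate_step) and the model repo are assumptions for illustration, not taken from this diff:

    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.utils import generate_step  # module path assumed for this utils.py

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")  # illustrative
    prompt_tokens = mx.array(tokenizer.encode("The quick brown fox"))

    max_tokens = 50
    tokens = []
    # generate_step yields (token, prob) pairs indefinitely; zip with range() caps it.
    for (token, prob), _ in zip(
        generate_step(prompt_tokens, model, 0.7, 1.2, 20),  # temp, repetition_penalty, repetition_context_size
        range(max_tokens),
    ):
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token.item())

    print(tokenizer.decode(tokens))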
@@ -125,6 +179,8 @@ def generate(
     max_tokens: int = 100,
     verbose: bool = False,
     formatter: Callable = None,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = None,
 ) -> str:
     """
     Generate text from the model.
@@ -139,20 +195,31 @@ def generate(
             (default ``False``).
         formatter (Optional[Callable]): A function which takes a token and a
             probability and displays it.
+        repetition_penalty (float, optional): The penalty factor for repeating tokens.
+        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
     """

     if verbose:
         print("=" * 10)
         print("Prompt:", prompt)

-    prompt = mx.array(tokenizer.encode(prompt))
+    prompt_tokens = mx.array(tokenizer.encode(prompt))

     tic = time.perf_counter()
     tokens = []
     skip = 0
     REPLACEMENT_CHAR = "\ufffd"

-    for (token, prob), n in zip(generate_step(prompt, model, temp), range(max_tokens)):
+    for (token, prob), n in zip(
+        generate_step(
+            prompt_tokens,
+            model,
+            temp,
+            repetition_penalty,
+            repetition_context_size,
+        ),
+        range(max_tokens),
+    ):
         if token == tokenizer.eos_token_id:
             break
         if n == 0:
@@ -179,7 +246,7 @@ def generate(
         if token_count == 0:
             print("No tokens generated for this prompt")
             return
-        prompt_tps = prompt.size / prompt_time
+        prompt_tps = prompt_tokens.size / prompt_time
         gen_tps = (token_count - 1) / gen_time
         print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
         print(f"Generation: {gen_tps:.3f} tokens-per-sec")