Add repetition penalty to LLM inference - mlx-lm (#399)

* feat: add repetition penalty

* fix: generate function argument fix

* typo fixes

* update repetition penalty

* update generate_step and generate

* resolve conflicts in generate

* merge latest pull from origin master

* update generate

* update generate and generate_step

* update repetition list - rename variable

* refactor token count

* update generate step and generate

* move repetition_context in generate_step

* update generate step

* update generate_step
vishal-14069 2024-02-17 00:58:17 -05:00 committed by GitHub
parent 0ba466369f
commit 21e19b5b5a


@@ -5,7 +5,7 @@ import json
 import logging
 import time
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -80,10 +80,36 @@ def get_model_path(path_or_hf_repo: str) -> Path:
     return model_path


+def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
+    """
+    Apply repetition penalty to specific logits based on the given context.
+
+    Paper: https://arxiv.org/abs/1909.05858
+
+    Args:
+        logits (mx.array): The logits produced by the language model.
+        generated_tokens (any): A list of N previous tokens.
+        penalty (float): The repetition penalty factor to be applied.
+
+    Returns:
+        logits (mx.array): Logits with repetition penalty applied to generated tokens.
+    """
+    if len(generated_tokens) > 0:
+        indices = mx.array([token for token in generated_tokens])
+        selected_logits = logits[:, indices]
+        selected_logits = mx.where(
+            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
+        )
+        logits[:, indices] = selected_logits
+    return logits
+
+
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
     temp: float = 0.0,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = 20,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
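For intuition, the new helper implements the CTRL-style penalty (https://arxiv.org/abs/1909.05858): positive logits of previously seen tokens are divided by the penalty and negative ones are multiplied by it, so repeated tokens lose probability either way. A minimal standalone sketch of that rescaling, with illustrative values (not part of the diff):

import mlx.core as mx

# Five-token vocabulary; suppose tokens 1 and 3 were already generated.
logits = mx.array([[2.0, 1.0, 0.5, -1.0, 0.0]])
indices = mx.array([1, 3])
penalty = 1.5

selected = logits[:, indices]
# Positive logits shrink toward zero; negative logits move further below zero.
logits[:, indices] = mx.where(selected < 0, selected * penalty, selected / penalty)
print(logits)  # token 1: 1.0 -> ~0.67, token 3: -1.0 -> -1.5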
@@ -92,6 +118,9 @@ def generate_step(
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
         temp (float): The temperature for sampling, if 0 the argmax is used.
+        repetition_penalty (float, optional): The penalty factor for repeating tokens.
+        repetition_context_size (int, optional): The number of tokens to consider
+          for repetition penalty (default 20).

     Yields:
         Generator[Tuple[mx.array, mx.array]]: A generator producing
         one token and probability per call.
@@ -108,12 +137,37 @@ def generate_step(
         prob = softmax_logits[0, token]
         return token, prob

+    if repetition_penalty and (
+        repetition_penalty < 0 or not isinstance(repetition_penalty, float)
+    ):
+        raise ValueError(
+            f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
+        )
+
     y = prompt
     cache = None
+
+    repetition_context = prompt.tolist()
+
+    if repetition_context_size:
+        repetition_context = repetition_context[-repetition_context_size:]
+
     while True:
         logits, cache = model(y[None], cache=cache)
         logits = logits[:, -1, :]
-        y, prob = sample(logits)
+
+        if repetition_penalty:
+            logits = apply_repetition_penalty(
+                logits, repetition_context, repetition_penalty
+            )
+            y, prob = sample(logits)
+            repetition_context.append(y.item())
+        else:
+            y, prob = sample(logits)
+
+        if repetition_context_size:
+            if len(repetition_context) > repetition_context_size:
+                repetition_context = repetition_context[-repetition_context_size:]
         yield y, prob
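Taken together, generate_step now seeds repetition_context from the prompt, keeps only the last repetition_context_size tokens, and appends each newly sampled token. A hedged usage sketch, assuming a model and tokenizer are already loaded (the prompt text and parameter values are illustrative):

# Penalize any token seen in the most recent 30 tokens of context.
prompt = mx.array(tokenizer.encode("The quick brown fox"))
for (token, prob), _ in zip(
    generate_step(
        prompt, model, temp=0.7,
        repetition_penalty=1.2, repetition_context_size=30,
    ),
    range(50),  # cap generation at 50 tokens
):
    print(tokenizer.decode([token.item()]), end="", flush=True)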
@@ -125,6 +179,8 @@ def generate(
     max_tokens: int = 100,
     verbose: bool = False,
     formatter: Callable = None,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = None,
 ) -> str:
     """
     Generate text from the model.
@@ -139,20 +195,31 @@ def generate(
           (default ``False``).
         formatter (Optional[Callable]): A function which takes a token and a
           probability and displays it.
+        repetition_penalty (float, optional): The penalty factor for repeating tokens.
+        repetition_context_size (int, optional): The number of tokens to consider
+          for repetition penalty.
     """
     if verbose:
         print("=" * 10)
         print("Prompt:", prompt)

-    prompt = mx.array(tokenizer.encode(prompt))
+    prompt_tokens = mx.array(tokenizer.encode(prompt))

     tic = time.perf_counter()
     tokens = []
     skip = 0
     REPLACEMENT_CHAR = "\ufffd"

-    for (token, prob), n in zip(generate_step(prompt, model, temp), range(max_tokens)):
+    for (token, prob), n in zip(
+        generate_step(
+            prompt_tokens,
+            model,
+            temp,
+            repetition_penalty,
+            repetition_context_size,
+        ),
+        range(max_tokens),
+    ):
         if token == tokenizer.eos_token_id:
             break
         if n == 0:
@@ -179,7 +246,7 @@ def generate(
     if token_count == 0:
         print("No tokens generated for this prompt")
         return
-    prompt_tps = prompt.size / prompt_time
+    prompt_tps = prompt_tokens.size / prompt_time
     gen_tps = (token_count - 1) / gen_time
     print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
     print(f"Generation: {gen_tps:.3f} tokens-per-sec")