Min P implementation (#926)

* Min P implementation

* Change default to 0 (no min_p)

* nits

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Chime Ogbuji 2024-08-15 18:45:02 -04:00 committed by GitHub
parent 9b83004631
commit c50971e860
2 changed files with 67 additions and 1 deletions


@@ -5,6 +5,64 @@ from functools import partial
import mlx.core as mx

@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def min_p_sampling(
    logits: mx.array,
    min_p: float,
    min_tokens_to_keep: int = 1,
    temperature=1.0,
) -> mx.array:
    """
    Apply min-p sampling to the logits.

    Min-p keeps all tokens that are above a minimum probability, scaled by the
    probability of the most likely token. As a result, the filter is more
    aggressive given a very high-probability token.

    Args:
        logits: The logits from the model's output.
        min_p (float): Minimum token probability. Typical values are in the
            0.01-0.2 range, comparably selective as setting `top_p` in the
            0.99-0.8 range.
        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
            be filtered. Default: ``1``.
    """
    if not (0 <= min_p <= 1.0):
        raise ValueError(
            f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}"
        )
    if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
        raise ValueError(
            f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
        )
    # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605

    # Softmax probabilities
    probs = mx.softmax(logits * (1 / temperature), axis=-1)

    # Indices sorted in decreasing order
    sorted_indices = mx.argsort(-logits).squeeze(0)
    sorted_probs = probs[..., sorted_indices]

    # Top probability
    top_probs = probs[..., sorted_indices[0]]

    # Calculate the min_p threshold
    scaled_min_p = min_p * top_probs

    # Mask tokens that have a probability less than the scaled min_p,
    # but never mask the first `min_tokens_to_keep` highest-probability tokens
    tokens_to_remove = sorted_probs < scaled_min_p
    tokens_to_remove[..., :min_tokens_to_keep] = False

    # Zero out the probabilities of masked tokens so they cannot be sampled
    selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)

    # Sample from the remaining pool and map back to the original vocabulary index
    sorted_token = mx.random.categorical(mx.log(selected_probs))
    return sorted_indices[sorted_token]

@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
    """


@@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
# Local imports
from .models.base import KVCache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import apply_lora_layers
from .tuner.utils import dequantize as dequantize_model
@@ -133,6 +133,8 @@ def generate_step(
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = 20,
    top_p: float = 1.0,
    min_p: float = 0.0,
    min_tokens_to_keep: int = 1,
    logit_bias: Optional[Dict[int, float]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
@@ -149,6 +151,10 @@ def generate_step(
            consider for repetition penalty. Default: ``20``.
        top_p (float, optional): Nucleus sampling, higher means the model
            considers more less-likely words.
        min_p (float, optional): The minimum value (scaled by the top token's
            probability) that a token probability must have to be considered.
        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
            be filtered by min_p sampling.
        logit_bias (dictionary, optional): Additive logit bias.

    Yields:
@@ -168,6 +174,8 @@
        else:
            if top_p > 0 and top_p < 1.0:
                token = top_p_sampling(logits, top_p, temp)
            elif min_p != 0.0:
                token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
            else:
                token = categorical_sampling(logits, temp)
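End to end, the new parameters reach users through `generate_step`. A minimal sketch of driving it with min-p enabled follows; the `load` call, the model id, and the loop structure are illustrative assumptions rather than part of this commit. Note that because of the `elif`, `min_p` only takes effect when `temp` is nonzero and `top_p` is left at its default of 1.0.

import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step

# Hypothetical model id; any model loadable by mlx_lm works the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt = mx.array(tokenizer.encode("Write a haiku about autumn."))

tokens = []
# Sample at temperature 0.8, dropping tokens below 5% of the top token's probability.
for (token, _), _ in zip(generate_step(prompt, model, temp=0.8, min_p=0.05), range(64)):
    if token.item() == tokenizer.eos_token_id:
        break
    tokens.append(token.item())
print(tokenizer.decode(tokens))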