From c50971e8608e05c42d8078b7009dc5b9c33d25c0 Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Thu, 15 Aug 2024 18:45:02 -0400
Subject: [PATCH] Min P implementation (#926)

* Min P implementation

* Change default to 0 (no min_p)

* nits

* nits

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/sample_utils.py | 58 +++++++++++++++++++++++++++++++++++++
 llms/mlx_lm/utils.py        | 10 ++++++-
 2 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index 2e9c172e..20b008fa 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -5,6 +5,64 @@ from functools import partial
 
 import mlx.core as mx
 
+@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
+def min_p_sampling(
+    logits: mx.array,
+    min_p: float,
+    min_tokens_to_keep: int = 1,
+    temperature=1.0,
+) -> mx.array:
+    """
+    Apply min-p sampling to the logits.
+
+    Min-p keeps all tokens that are above a minimum probability, scaled by the
+    probability of the most likely token. As a result, the filter is more
+    aggressive given a very high-probability token.
+
+    Args:
+        logits: The logits from the model's output.
+        min_p (float): Minimum token probability. Typical values are in the
+            0.01-0.2 range, comparable in selectivity to setting `top_p` in
+            the 0.99-0.8 range.
+        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
+            be filtered. Default: ``1``.
+
+    """
+    if not (0 <= min_p <= 1.0):
+        raise ValueError(
+            f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}"
+        )
+    if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
+        raise ValueError(
+            f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
+        )
+    # Reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
+
+    # Softmax probabilities
+    probs = mx.softmax(logits * (1 / temperature), axis=-1)
+
+    # Indices sorted in decreasing order
+    sorted_indices = mx.argsort(-logits).squeeze(0)
+    sorted_probs = probs[..., sorted_indices]
+
+    # Top probability
+    top_probs = probs[..., sorted_indices[0]]
+
+    # Calculate the min_p threshold
+    scaled_min_p = min_p * top_probs
+
+    # Mask tokens that have a probability less than the scaled min_p
+    tokens_to_remove = sorted_probs < scaled_min_p
+    tokens_to_remove[..., :min_tokens_to_keep] = False
+
+    # Create pool of tokens with probability at least the scaled min_p
+    selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)
+
+    # Return sampled token
+    sorted_token = mx.random.categorical(mx.log(selected_probs))
+    return sorted_indices[sorted_token]
+
+
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
 def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
     """
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index a34cc6ad..e7a9dba8 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
 
 # Local imports
 from .models.base import KVCache
-from .sample_utils import categorical_sampling, top_p_sampling
+from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import apply_lora_layers
 from .tuner.utils import dequantize as dequantize_model
@@ -133,6 +133,8 @@ def generate_step(
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = 20,
     top_p: float = 1.0,
+    min_p: float = 0.0,
+    min_tokens_to_keep: int = 1,
     logit_bias: Optional[Dict[int, float]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
@@ -149,6 +151,10 @@ def generate_step(
             consider for repetition penalty. Default: ``20``.
         top_p (float, optional): Nucleus sampling; higher values make the
             model consider less likely words.
+        min_p (float, optional): The minimum value (scaled by the top token's
+            probability) that a token probability must have to be considered.
+        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
+            be filtered by min_p sampling.
         logit_bias (dictionary, optional): Additive logit bias.
 
     Yields:
@@ -168,6 +174,8 @@ def generate_step(
         else:
            if top_p > 0 and top_p < 1.0:
                token = top_p_sampling(logits, top_p, temp)
+           elif min_p != 0.0:
+               token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
            else:
                token = categorical_sampling(logits, temp)
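
Usage sketch (not part of the patch): a minimal example of how the new sampler behaves on a toy distribution, assuming `mlx` and an `mlx-lm` build containing this change are installed; the probabilities below are illustrative, not from the source.

import mlx.core as mx

from mlx_lm.sample_utils import min_p_sampling

# Toy next-token distribution over a 5-token vocabulary (batch size 1).
# Token 0 dominates; tokens 3 and 4 are long-tail noise.
probs = mx.array([[0.60, 0.20, 0.15, 0.04, 0.01]])
logits = mx.log(probs)

# With min_p=0.1 the threshold is 0.1 * 0.60 = 0.06, so tokens 3 (0.04) and
# 4 (0.01) fall below it and sampling is restricted to tokens 0-2.
# The positional arguments mirror the call in generate_step:
# (logits, min_p, min_tokens_to_keep, temperature).
for _ in range(5):
    token = min_p_sampling(logits, 0.1, 1, 1.0)
    print(token.item())  # always one of 0, 1, or 2

The same filter is reachable end to end through generate_step's new min_p and min_tokens_to_keep keyword arguments; note that top_p takes precedence when both are set, since the top_p branch is checked first.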