mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Min P implementation (#926)
* Min P implementation * Change default to 0 (no min_p) * nits * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
9b83004631
commit
c50971e860
@ -5,6 +5,64 @@ from functools import partial
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
|
def min_p_sampling(
|
||||||
|
logits: mx.array,
|
||||||
|
min_p: float,
|
||||||
|
min_tokens_to_keep: int = 1,
|
||||||
|
temperature=1.0,
|
||||||
|
) -> mx.array:
|
||||||
|
"""
|
||||||
|
Apply min-p sampling to the logits.
|
||||||
|
|
||||||
|
Min-p keeps all tokens that are above a minimum probability, scaled by the
|
||||||
|
probability of the most likely token. As a result, the filter is more
|
||||||
|
aggressive given a very high-probability token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: The logits from the model's output.
|
||||||
|
min_p (float): Minimum token probability. Typical values are in the
|
||||||
|
0.01-0.2 range, comparably selective as setting `top_p` in the
|
||||||
|
0.99-0.8 range.
|
||||||
|
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||||
|
be filtered. Default: ``1``.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not (0 <= min_p <= 1.0):
|
||||||
|
raise ValueError(
|
||||||
|
f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}"
|
||||||
|
)
|
||||||
|
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
|
||||||
|
raise ValueError(
|
||||||
|
f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
|
||||||
|
)
|
||||||
|
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
|
||||||
|
|
||||||
|
# Softmax probabilities
|
||||||
|
probs = mx.softmax(logits * (1 / temperature), axis=-1)
|
||||||
|
|
||||||
|
# Indices sorted in decreasing order
|
||||||
|
sorted_indices = mx.argsort(-logits).squeeze(0)
|
||||||
|
sorted_probs = probs[..., sorted_indices]
|
||||||
|
|
||||||
|
# Top probability
|
||||||
|
top_probs = probs[..., sorted_indices[0]]
|
||||||
|
|
||||||
|
# Calculate the min_p threshold
|
||||||
|
scaled_min_p = min_p * top_probs
|
||||||
|
|
||||||
|
# Mask tokens that have a probability less than the scaled min_p
|
||||||
|
tokens_to_remove = sorted_probs < scaled_min_p
|
||||||
|
tokens_to_remove[..., :min_tokens_to_keep] = False
|
||||||
|
|
||||||
|
# Create pool of tokens with probability less than scaled min_p
|
||||||
|
selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)
|
||||||
|
|
||||||
|
# Return sampled token
|
||||||
|
sorted_token = mx.random.categorical(mx.log(selected_probs))
|
||||||
|
return sorted_indices[sorted_token]
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
|
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
|
||||||
"""
|
"""
|
||||||
|
@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
|
|||||||
|
|
||||||
# Local imports
|
# Local imports
|
||||||
from .models.base import KVCache
|
from .models.base import KVCache
|
||||||
from .sample_utils import categorical_sampling, top_p_sampling
|
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
|
||||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||||
from .tuner.utils import apply_lora_layers
|
from .tuner.utils import apply_lora_layers
|
||||||
from .tuner.utils import dequantize as dequantize_model
|
from .tuner.utils import dequantize as dequantize_model
|
||||||
@ -133,6 +133,8 @@ def generate_step(
|
|||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
repetition_context_size: Optional[int] = 20,
|
repetition_context_size: Optional[int] = 20,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
|
min_p: float = 0.0,
|
||||||
|
min_tokens_to_keep: int = 1,
|
||||||
logit_bias: Optional[Dict[int, float]] = None,
|
logit_bias: Optional[Dict[int, float]] = None,
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
@ -149,6 +151,10 @@ def generate_step(
|
|||||||
consider for repetition penalty. Default: ``20``.
|
consider for repetition penalty. Default: ``20``.
|
||||||
top_p (float, optional): Nulceus sampling, higher means model considers
|
top_p (float, optional): Nulceus sampling, higher means model considers
|
||||||
more less likely words.
|
more less likely words.
|
||||||
|
min_p (float, optional): The minimum value (scaled by the top token's
|
||||||
|
probability) that a token probability must have to be considered.
|
||||||
|
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||||
|
be filtered by min_p sampling.
|
||||||
logit_bias (dictionary, optional): Additive logit bias.
|
logit_bias (dictionary, optional): Additive logit bias.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
@ -168,6 +174,8 @@ def generate_step(
|
|||||||
else:
|
else:
|
||||||
if top_p > 0 and top_p < 1.0:
|
if top_p > 0 and top_p < 1.0:
|
||||||
token = top_p_sampling(logits, top_p, temp)
|
token = top_p_sampling(logits, top_p, temp)
|
||||||
|
elif min_p != 0.0:
|
||||||
|
token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
|
||||||
else:
|
else:
|
||||||
token = categorical_sampling(logits, temp)
|
token = categorical_sampling(logits, temp)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user