2024-03-22 03:18:23 +08:00
|
|
|
import mlx.core as mx
|
|
|
|
|
|
|
|
|
|
|
|
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
    """
    Apply top-p (nucleus) sampling to logits.

    The logits are converted to a temperature-scaled probability
    distribution and the tokens are sorted by ascending probability. The
    low-probability tail whose cumulative mass does not exceed ``1 - top_p``
    is discarded, and one token is drawn from the remaining candidates in
    proportion to their probabilities.

    Args:
        logits: The logits from the model's output. The last axis is treated
            as the vocabulary axis; every leading position (for example each
            row of a batch) is sampled independently.
        top_p: The cumulative probability threshold for top-p filtering.
        temperature: Temperature parameter for softmax distribution
            reshaping. The logits are divided by it, so it must be non-zero.

    Returns:
        The sampled token id(s), shaped like ``logits`` with the last axis
        removed (e.g. ``(1,)`` for ``(1, vocab_size)`` logits).
    """
    # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
    probs = mx.softmax(logits / temperature, axis=-1)

    # Sort ascending so the unlikely tail comes first in the cumulative sum.
    # take_along_axis gathers per row, so no particular batch size is assumed.
    sorted_indices = mx.argsort(probs, axis=-1)
    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
    cumulative_probs = mx.cumsum(sorted_probs, axis=-1)

    # Keep the tokens whose running mass exceeds 1 - top_p (the most likely
    # ones, since the order is ascending); zero out the tail.
    top_probs = mx.where(
        cumulative_probs > 1 - top_p,
        sorted_probs,
        mx.zeros_like(sorted_probs),
    )

    # Always keep the most likely token (last position after the ascending
    # sort). Without this, a top_p at or near 0, or round-off in the
    # cumulative sum, leaves no candidate and every logit becomes -inf.
    vocab_size = sorted_probs.shape[-1]
    is_most_likely = mx.arange(vocab_size) == vocab_size - 1
    top_probs = mx.where(is_most_likely, sorted_probs, top_probs)

    # categorical expects (unnormalized) log-probabilities; log(0) = -inf
    # gives the discarded tokens zero chance of being drawn.
    sorted_token = mx.random.categorical(mx.log(top_probs), axis=-1)

    # Map positions in the sorted order back to vocabulary ids.
    sorted_token = mx.expand_dims(sorted_token, -1)
    token = mx.take_along_axis(sorted_indices, sorted_token, axis=-1)

    return token.squeeze(-1)
|