mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
45 lines
1.4 KiB
Python
45 lines
1.4 KiB
Python
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
from functools import partial
|
|
|
|
import mlx.core as mx
|
|
|
|
|
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
    """
    Apply top-p (nucleus) sampling to logits.

    The temperature-scaled distribution is sorted in ascending order. Entries
    whose running (cumulative) probability does not exceed ``1 - top_p`` are
    zeroed, so only the high-probability tail survives. One token is then
    sampled from the surviving probabilities independently for every row.

    Args:
        logits: The logits from the model's output. The last axis is the
            vocabulary axis. Any leading axes are treated as batch axes, so
            both ``(1, vocab)`` and ``(batch, vocab)`` inputs are supported.
        top_p: The cumulative probability threshold for top-p filtering.
        temperature: Temperature parameter for softmax distribution reshaping.

    Returns:
        The sampled vocabulary index for each row, with the vocabulary axis
        removed (shape ``logits.shape[:-1]``; e.g. ``(1,)`` for a
        ``(1, vocab)`` input).
    """
    # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
    probs = mx.softmax(logits * (1 / temperature), axis=-1)

    # Sort probs in ascending order. take_along_axis gathers each row with its
    # own sort order. Fancy-indexing with one shared (squeezed) index vector
    # only works when the batch size is exactly 1.
    sorted_indices = mx.argsort(probs, axis=-1)
    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)

    cumulative_probs = mx.cumsum(sorted_probs, axis=-1)

    # Keep the high-probability tail. In ascending order, an entry's running
    # sum exceeds 1 - top_p only once the remaining (larger) entries, including
    # itself, carry the top_p mass. Everything before that point is zeroed.
    top_probs = mx.where(
        cumulative_probs > 1 - top_p,
        sorted_probs,
        0,
    )

    # Zeroed entries become log(0) = -inf, so they are never drawn. The result
    # is a position within each row's *sorted* order.
    sorted_token = mx.random.categorical(mx.log(top_probs), axis=-1)

    # Map sorted positions back to vocabulary ids, row by row, then drop the
    # length-1 gather axis.
    token = mx.take_along_axis(
        sorted_indices, mx.expand_dims(sorted_token, axis=-1), axis=-1
    )
    return token.squeeze(-1)
|
|
|
|
|
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def categorical_sampling(logits: mx.array, temp: float) -> mx.array:
    """
    Sample from temperature-scaled logits.

    The decorator threads ``mx.random.state`` into and out of the compiled
    function, so repeated calls advance the global random state.

    Args:
        logits: Unnormalized log-probabilities, passed to
            ``mx.random.categorical`` with its default axis.
        temp: Sampling temperature. Logits are multiplied by ``1 / temp``.

    Returns:
        The indices drawn by ``mx.random.categorical``.
    """
    inverse_temp = 1 / temp
    scaled_logits = logits * inverse_temp
    return mx.random.categorical(scaled_logits)
|