mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
chore(mlx-lm): fix the top_p implementation. (#602)
* chore(mlx-lm): clean up the top p imp * chore: clean up * chore: add test * chore: address comments * chore: clean up docs string * chore: clean up test
This commit is contained in:
parent
fe96ef342f
commit
fbed720d6f
39
llms/mlx_lm/sample_utils.py
Normal file
39
llms/mlx_lm/sample_utils.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
|
||||||
|
"""
|
||||||
|
Apply top-p (nucleus) sampling to logits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: The logits from the model's output.
|
||||||
|
top_p: The cumulative probability threshold for top-p filtering.
|
||||||
|
temperature: Temperature parameter for softmax distribution reshaping.
|
||||||
|
Returns:
|
||||||
|
token selected based on the top-p criterion.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
logits.dtype == mx.bfloat16
|
||||||
|
): # workaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
|
||||||
|
logits = logits.astype(mx.float32)
|
||||||
|
|
||||||
|
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
|
||||||
|
probs = mx.softmax(logits / temperature, axis=-1)
|
||||||
|
|
||||||
|
# sort probs in ascending order
|
||||||
|
sorted_indices = mx.argsort(probs, axis=-1)
|
||||||
|
sorted_probs = probs[..., sorted_indices]
|
||||||
|
|
||||||
|
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
||||||
|
|
||||||
|
# select tokens with cumulative probs below threshold
|
||||||
|
top_probs = mx.where(
|
||||||
|
cumulative_probs > 1 - top_p,
|
||||||
|
sorted_probs,
|
||||||
|
mx.zeros_like(sorted_probs),
|
||||||
|
)
|
||||||
|
|
||||||
|
sorted_token = mx.random.categorical(mx.log(top_probs))
|
||||||
|
token = sorted_indices.squeeze(0)[sorted_token]
|
||||||
|
|
||||||
|
return token
|
@ -17,6 +17,8 @@ from huggingface_hub import snapshot_download
|
|||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
||||||
|
|
||||||
|
from .sample_utils import top_p_sampling
|
||||||
|
|
||||||
# Local imports
|
# Local imports
|
||||||
from .tuner.utils import apply_lora_layers
|
from .tuner.utils import apply_lora_layers
|
||||||
from .tuner.utils import dequantize as dequantize_model
|
from .tuner.utils import dequantize as dequantize_model
|
||||||
@ -144,23 +146,7 @@ def generate_step(
|
|||||||
token = mx.argmax(logits, axis=-1)
|
token = mx.argmax(logits, axis=-1)
|
||||||
else:
|
else:
|
||||||
if top_p > 0 and top_p < 1.0:
|
if top_p > 0 and top_p < 1.0:
|
||||||
if (
|
token = top_p_sampling(logits, top_p, temp)
|
||||||
logits.dtype == mx.bfloat16
|
|
||||||
): # workdaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
|
|
||||||
logits = logits.astype(mx.float32)
|
|
||||||
probs = mx.softmax(logits / temp, axis=-1)
|
|
||||||
|
|
||||||
sorted_probs = mx.sort(probs)[::-1]
|
|
||||||
sorted_indices = mx.argsort(probs)[::-1]
|
|
||||||
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
|
||||||
|
|
||||||
top_probs = mx.where(
|
|
||||||
cumulative_probs > 1 - top_p,
|
|
||||||
sorted_probs,
|
|
||||||
mx.zeros_like(sorted_probs),
|
|
||||||
)
|
|
||||||
sorted_token = mx.random.categorical(mx.log(top_probs))
|
|
||||||
token = sorted_indices.squeeze(0)[sorted_token]
|
|
||||||
else:
|
else:
|
||||||
token = mx.random.categorical(logits * (1 / temp))
|
token = mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
|
37
llms/tests/test_sample_utils.py
Normal file
37
llms/tests/test_sample_utils.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx_lm.sample_utils import top_p_sampling
|
||||||
|
|
||||||
|
|
||||||
|
class TestLora(unittest.TestCase):
|
||||||
|
@patch("mlx.core.random.categorical")
|
||||||
|
def test_top_p_sampling(self, mock_categorical):
|
||||||
|
logits = mx.array([[1.0, 2.0, 3.0, 4.0]])
|
||||||
|
top_p = 0.3
|
||||||
|
temperature = 1.0
|
||||||
|
expected_token = mx.array([3])
|
||||||
|
mock_categorical.return_value = expected_token
|
||||||
|
|
||||||
|
token = top_p_sampling(logits, top_p, temperature)
|
||||||
|
expected_top_probs = mx.array([[0.0, 0.0, 0.0, 0.643914]])
|
||||||
|
self.assertTrue(mx.allclose(token, expected_token))
|
||||||
|
args, _ = mock_categorical.call_args
|
||||||
|
self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs)))
|
||||||
|
|
||||||
|
logits = mx.array([[1.0, 2.0, 3.0, 4.0]])
|
||||||
|
top_p = 0.9
|
||||||
|
temperature = 1.0
|
||||||
|
expected_token = mx.array([3])
|
||||||
|
mock_categorical.return_value = expected_token
|
||||||
|
|
||||||
|
token = top_p_sampling(logits, top_p, temperature)
|
||||||
|
expected_top_probs = mx.array([[0.0, 0.0871443, 0.236883, 0.643914]])
|
||||||
|
self.assertTrue(mx.allclose(token, expected_token))
|
||||||
|
args, _ = mock_categorical.call_args
|
||||||
|
self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue
Block a user