diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index f22ce2d7..2e9c172e 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -1,6 +1,11 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from functools import partial
+
 import mlx.core as mx
 
 
+@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
 def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
     """
     Apply top-p (nucleus) sampling to logits.
@@ -13,7 +18,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
         token selected based on the top-p criterion.
     """
     # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
-    probs = mx.softmax(logits / temperature, axis=-1)
+    probs = mx.softmax(logits * (1 / temperature), axis=-1)
 
     # sort probs in ascending order
     sorted_indices = mx.argsort(probs, axis=-1)
@@ -25,10 +30,15 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
     top_probs = mx.where(
         cumulative_probs > 1 - top_p,
         sorted_probs,
-        mx.zeros_like(sorted_probs),
+        0,
     )
 
     sorted_token = mx.random.categorical(mx.log(top_probs))
     token = sorted_indices.squeeze(0)[sorted_token]
 
     return token
+
+
+@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
+def categorical_sampling(logits, temp):
+    return mx.random.categorical(logits * (1 / temp))
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 0e7f7a39..a34cc6ad 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
 
 # Local imports
 from .models.base import KVCache
-from .sample_utils import top_p_sampling
+from .sample_utils import categorical_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import apply_lora_layers
 from .tuner.utils import dequantize as dequantize_model
@@ -169,7 +169,7 @@ def generate_step(
             if top_p > 0 and top_p < 1.0:
                 token = top_p_sampling(logits, top_p, temp)
             else:
-                token = mx.random.categorical(logits * (1 / temp))
+                token = categorical_sampling(logits, temp)
 
         return token, logprobs
 
diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py
index 0bccdd07..ec0e2cb7 100644
--- a/llms/tests/test_sample_utils.py
+++ b/llms/tests/test_sample_utils.py
@@ -1,38 +1,32 @@
 import unittest
-from unittest.mock import patch
 
 import mlx.core as mx
 from mlx_lm.sample_utils import top_p_sampling
 
 
 class TestSamplingUtils(unittest.TestCase):
-    @patch("mlx.core.random.categorical")
-    def test_top_p_sampling(self, mock_categorical):
-        logits = mx.array([[1.0, 2.0, 3.0, 4.0]])
-        top_p = 0.3
+    def test_top_p_sampling(self):
+        probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
+        logits = mx.log(probs)
         temperature = 1.0
-        expected_token = mx.array([3])
-        mock_categorical.return_value = expected_token
 
-        token = top_p_sampling(logits, top_p, temperature)
-        expected_top_probs = mx.array([[0.0, 0.0, 0.0, 0.643914]])
-        self.assertTrue(mx.allclose(token, expected_token))
-        args, _ = mock_categorical.call_args
-        self.assertTrue(args[0].shape == expected_top_probs.shape)
-        self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs)))
+        token = top_p_sampling(logits, 0.3, temperature).item()
+        self.assertEqual(token, 0)
 
-        logits = mx.array([[1.0, 2.0, 3.0, 4.0]])
-        top_p = 0.9
-        temperature = 1.0
-        expected_token = mx.array([3])
-        mock_categorical.return_value = expected_token
+        token = top_p_sampling(logits, 0.95, temperature).item()
+        self.assertTrue(token in (0, 3))
 
-        token = top_p_sampling(logits, top_p, temperature)
-        expected_top_probs = mx.array([[0.0, 0.0871443, 0.236883, 0.643914]])
-        self.assertTrue(mx.allclose(token, expected_token))
-        args, _ = mock_categorical.call_args
-        self.assertTrue(args[0].shape == expected_top_probs.shape)
-        self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs)))
+        probs = mx.array([0.0, 0.5, 0.4, 0.1])[None]
+        logits = mx.log(probs)
+
+        token = top_p_sampling(logits, 0.4, temperature).item()
+        self.assertEqual(token, 1)
+
+        token = top_p_sampling(logits, 0.6, temperature).item()
+        self.assertTrue(token in (1, 2))
+
+        token = top_p_sampling(logits, 0.95, temperature).item()
+        self.assertTrue(token in (1, 2, 3))
 
 
 if __name__ == "__main__":
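For reference, a minimal usage sketch of the two compiled samplers added above; the logit values are illustrative and not taken from the patch, and only the import path shown in the tests (mlx_lm.sample_utils) is assumed.

import mlx.core as mx

from mlx_lm.sample_utils import categorical_sampling, top_p_sampling

# Toy batch with one distribution over a 4-token vocabulary (illustrative values).
logits = mx.log(mx.array([[0.6, 0.3, 0.05, 0.05]]))

# Nucleus sampling: with top_p=0.5 only the highest-probability token (0.6)
# survives the cumulative-probability cutoff, so this always returns index 0.
nucleus_token = top_p_sampling(logits, 0.5, 1.0)

# Plain temperature sampling through the compiled helper; both calls update
# mx.random.state because it is passed as compile input/output state.
plain_token = categorical_sampling(logits, 1.0)

print(nucleus_token.item(), plain_token.item())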