diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index 23e08d97..d7049f7d 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -169,19 +169,18 @@ def min_p_sampling( @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) -def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array: +def top_p_sampling(logits: mx.array, top_p: float) -> mx.array: """ Apply top-p (nucleus) sampling to logits. Args: logits: The logits from the model's output. top_p: The cumulative probability threshold for top-p filtering. - temperature: Temperature parameter for softmax distribution reshaping. Returns: token selected based on the top-p criterion. """ # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 - probs = mx.softmax(logits * (1 / temperature), axis=-1) + probs = mx.softmax(logits, axis=-1) # sort probs in ascending order sorted_indices = mx.argsort(probs, axis=-1) @@ -196,8 +195,15 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr 0, ) - sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None] - return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1) + # Create a mapping to rearrange back to original indices + # Use argsort of sorted_indices to get the inverse permutation + inverse_indices = mx.argsort(sorted_indices, axis=-1) + + # Rearrange top_probs back to original order + original_order_probs = mx.take_along_axis(top_probs, inverse_indices, axis=-1) + + # Convert back to logits and return + return mx.log(mx.where(original_order_probs > 0, original_order_probs, 0)) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py index f12abbf4..5a3d8847 100644 --- a/llms/tests/test_sample_utils.py +++ b/llms/tests/test_sample_utils.py @@ -8,31 +8,42 @@ class TestSampleUtils(unittest.TestCase): def test_top_p_sampling(self): probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] logits = mx.log(probs) - temperature = 1.0 - token = top_p_sampling(logits, 0.3, temperature).item() - self.assertEqual(token, 0) + actual_logits = top_p_sampling(logits, 0.3) + actual_probs = mx.softmax(actual_logits.squeeze()) + self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0]) - token = top_p_sampling(logits, 0.95, temperature).item() - self.assertTrue(token in (0, 3)) + actual_logits = top_p_sampling(logits, 0.95) + actual_probs = mx.softmax(actual_logits.squeeze()) + self.assertEqual(probs.squeeze().tolist(), actual_probs.tolist()) probs = mx.array([0.0, 0.5, 0.4, 0.1])[None] logits = mx.log(probs) + actual_logits = top_p_sampling(logits, 0.4) + actual_probs = mx.softmax(actual_logits.squeeze()) + self.assertEqual(actual_probs.tolist(), [0.0, 1.0, 0.0, 0.0]) - token = top_p_sampling(logits, 0.4, temperature).item() - self.assertEqual(token, 1) + actual_logits = top_p_sampling(logits, 0.6) + actual_probs = mx.softmax(actual_logits.squeeze()) + self.assertEqual( + [round(p, 4) for p in actual_probs.tolist()], [0.0, 0.5556, 0.4444, 0.0] + ) - token = top_p_sampling(logits, 0.6, temperature).item() - self.assertTrue(token in (1, 2)) - - token = top_p_sampling(logits, 0.95, temperature).item() - self.assertTrue(token in (1, 2, 3)) + actual_logits = top_p_sampling(logits, 0.95) + actual_probs = mx.softmax(actual_logits.squeeze()) + actual_rounded = [round(p, 4) for p in actual_probs.tolist()] + expected_rounded = [0.0, 0.5, 0.4, 0.1] + self.assertEqual(actual_rounded, expected_rounded) + self.assertAlmostEqual(sum(actual_probs.tolist()), 1.0) # Batch mode works - probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]]) + probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.1, 0.1]]) logits = mx.log(probs) - tokens = top_p_sampling(logits, 0.5, temperature) - self.assertEqual(tokens.tolist(), [0, 1]) + actual_logits = top_p_sampling(logits, 0.5) + actual_probs = mx.softmax(actual_logits, axis=-1) + self.assertEqual( + actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]] + ) def test_min_p_sampling(self): probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]