From fbed720d6f9ac7c854c1fdd8d9954fa1ba47691c Mon Sep 17 00:00:00 2001
From: Anchen
Date: Fri, 22 Mar 2024 06:18:23 +1100
Subject: [PATCH] chore(mlx-lm): fix the top_p implementation. (#602)
* chore(mlx-lm): clean up the top p imp
* chore: clean up
* chore: add test
* chore: address comments
* chore: clean up docs string
* chore: clean up test
---
llms/mlx_lm/sample_utils.py | 39 +++++++++++++++++++++++++++++++++
llms/mlx_lm/utils.py | 20 +++--------------
llms/tests/test_sample_utils.py | 37 +++++++++++++++++++++++++++++++
3 files changed, 79 insertions(+), 17 deletions(-)
create mode 100644 llms/mlx_lm/sample_utils.py
create mode 100644 llms/tests/test_sample_utils.py
diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
new file mode 100644
index 00000000..2b793672
--- /dev/null
+++ b/llms/mlx_lm/sample_utils.py
@@ -0,0 +1,39 @@
+import mlx.core as mx
+
+
+def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
+ """
+ Apply top-p (nucleus) sampling to logits.
+
+ Args:
+ logits: The logits from the model's output.
+ top_p: The cumulative probability threshold for top-p filtering.
+ temperature: Temperature parameter for softmax distribution reshaping.
+ Returns:
+ The token selected based on the top-p (nucleus) criterion.
+ """
+ if (
+ logits.dtype == mx.bfloat16
+ ): # workaround: the kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16 cannot be loaded for bfloat16 inputs
+ logits = logits.astype(mx.float32)
+
+ # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
+ probs = mx.softmax(logits / temperature, axis=-1)
+
+ # sort probs in ascending order
+ sorted_indices = mx.argsort(probs, axis=-1)
+ sorted_probs = probs[..., sorted_indices] # NOTE(review): with (1, vocab) indices this broadcasts to shape (1, 1, vocab) — appears to assume batch size 1 (see the squeeze(0) below); confirm
+
+ cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
+
+ # probs are sorted ascending, so keep the high-probability suffix: tokens whose inclusive cumulative mass exceeds 1 - top_p (the nucleus covering top_p); zero out the rest
+ top_probs = mx.where(
+ cumulative_probs > 1 - top_p,
+ sorted_probs,
+ mx.zeros_like(sorted_probs),
+ )
+
+ sorted_token = mx.random.categorical(mx.log(top_probs))
+ token = sorted_indices.squeeze(0)[sorted_token]
+
+ return token
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 7b0e2da7..03e0fbd3 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -17,6 +17,8 @@ from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
+from .sample_utils import top_p_sampling
+
# Local imports
from .tuner.utils import apply_lora_layers
from .tuner.utils import dequantize as dequantize_model
@@ -144,23 +146,7 @@ def generate_step(
token = mx.argmax(logits, axis=-1)
else:
if top_p > 0 and top_p < 1.0:
- if (
- logits.dtype == mx.bfloat16
- ): # workdaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
- logits = logits.astype(mx.float32)
- probs = mx.softmax(logits / temp, axis=-1)
-
- sorted_probs = mx.sort(probs)[::-1]
- sorted_indices = mx.argsort(probs)[::-1]
- cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
-
- top_probs = mx.where(
- cumulative_probs > 1 - top_p,
- sorted_probs,
- mx.zeros_like(sorted_probs),
- )
- sorted_token = mx.random.categorical(mx.log(top_probs))
- token = sorted_indices.squeeze(0)[sorted_token]
+ token = top_p_sampling(logits, top_p, temp)
else:
token = mx.random.categorical(logits * (1 / temp))
diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py
new file mode 100644
index 00000000..f02560a6
--- /dev/null
+++ b/llms/tests/test_sample_utils.py
@@ -0,0 +1,37 @@
+import unittest
+from unittest.mock import patch
+
+import mlx.core as mx
+from mlx_lm.sample_utils import top_p_sampling
+
+
+class TestSampleUtils(unittest.TestCase):
+ @patch("mlx.core.random.categorical")
+ def test_top_p_sampling(self, mock_categorical):
+ logits = mx.array([[1.0, 2.0, 3.0, 4.0]])
+ top_p = 0.3
+ temperature = 1.0
+ expected_token = mx.array([3])
+ mock_categorical.return_value = expected_token
+
+ token = top_p_sampling(logits, top_p, temperature)
+ expected_top_probs = mx.array([[0.0, 0.0, 0.0, 0.643914]])
+ self.assertTrue(mx.allclose(token, expected_token))
+ args, _ = mock_categorical.call_args
+ self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs)))
+
+ logits = mx.array([[1.0, 2.0, 3.0, 4.0]])
+ top_p = 0.9
+ temperature = 1.0
+ expected_token = mx.array([3])
+ mock_categorical.return_value = expected_token
+
+ token = top_p_sampling(logits, top_p, temperature)
+ expected_top_probs = mx.array([[0.0, 0.0871443, 0.236883, 0.643914]])
+ self.assertTrue(mx.allclose(token, expected_token))
+ args, _ = mock_categorical.call_args
+ self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs)))
+
+
+if __name__ == "__main__":
+ unittest.main()