From 0ab01b4626cfca974ea8616370da2d0e3254a205 Mon Sep 17 00:00:00 2001 From: Anchen Date: Tue, 26 Mar 2024 09:07:55 +1100 Subject: [PATCH] fix(mlx-lm): sorted probs in top_p implementation. (#610) * fix(mlx-lm): the top p imp * chore: address comment --- llms/mlx_lm/sample_utils.py | 2 +- llms/tests/test_sample_utils.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index 2b793672..1953aeea 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -22,7 +22,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr # sort probs in ascending order sorted_indices = mx.argsort(probs, axis=-1) - sorted_probs = probs[..., sorted_indices] + sorted_probs = probs[..., sorted_indices.squeeze(0)] cumulative_probs = mx.cumsum(sorted_probs, axis=-1) diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py index f02560a6..8b960736 100644 --- a/llms/tests/test_sample_utils.py +++ b/llms/tests/test_sample_utils.py @@ -18,6 +18,7 @@ class TestLora(unittest.TestCase): expected_top_probs = mx.array([[0.0, 0.0, 0.0, 0.643914]]) self.assertTrue(mx.allclose(token, expected_token)) args, _ = mock_categorical.call_args + self.assertTrue(args[0].shape == expected_top_probs.shape) self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs))) logits = mx.array([[1.0, 2.0, 3.0, 4.0]]) @@ -30,6 +31,7 @@ class TestLora(unittest.TestCase): expected_top_probs = mx.array([[0.0, 0.0871443, 0.236883, 0.643914]]) self.assertTrue(mx.allclose(token, expected_token)) args, _ = mock_categorical.call_args + self.assertTrue(args[0].shape == expected_top_probs.shape) self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs)))