fix(mlx-lm): sorted probs in top_p implementation. (#610)

* fix(mlx-lm): the top-p implementation

* chore: address comment
This commit is contained in:
Anchen 2024-03-26 09:07:55 +11:00 committed by GitHub
parent bbfcc103d7
commit 0ab01b4626
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 1 deletion

View File

@ -22,7 +22,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
# sort probs in ascending order
sorted_indices = mx.argsort(probs, axis=-1)
sorted_probs = probs[..., sorted_indices]
sorted_probs = probs[..., sorted_indices.squeeze(0)]
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)

View File

@ -18,6 +18,7 @@ class TestLora(unittest.TestCase):
expected_top_probs = mx.array([[0.0, 0.0, 0.0, 0.643914]])
self.assertTrue(mx.allclose(token, expected_token))
args, _ = mock_categorical.call_args
self.assertTrue(args[0].shape == expected_top_probs.shape)
self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs)))
logits = mx.array([[1.0, 2.0, 3.0, 4.0]])
@ -30,6 +31,7 @@ class TestLora(unittest.TestCase):
expected_top_probs = mx.array([[0.0, 0.0871443, 0.236883, 0.643914]])
self.assertTrue(mx.allclose(token, expected_token))
args, _ = mock_categorical.call_args
self.assertTrue(args[0].shape == expected_top_probs.shape)
self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs)))