mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-23 22:18:06 +08:00
fix(mlx-lm): sorted probs in top_p implementation. (#610)
* fix(mlx-lm): the top p imp * chore: address comment
This commit is contained in:
@@ -18,6 +18,7 @@ class TestLora(unittest.TestCase):
|
||||
expected_top_probs = mx.array([[0.0, 0.0, 0.0, 0.643914]])
|
||||
self.assertTrue(mx.allclose(token, expected_token))
|
||||
args, _ = mock_categorical.call_args
|
||||
self.assertTrue(args[0].shape == expected_top_probs.shape)
|
||||
self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs)))
|
||||
|
||||
logits = mx.array([[1.0, 2.0, 3.0, 4.0]])
|
||||
@@ -30,6 +31,7 @@ class TestLora(unittest.TestCase):
|
||||
expected_top_probs = mx.array([[0.0, 0.0871443, 0.236883, 0.643914]])
|
||||
self.assertTrue(mx.allclose(token, expected_token))
|
||||
args, _ = mock_categorical.call_args
|
||||
self.assertTrue(args[0].shape == expected_top_probs.shape)
|
||||
self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs)))
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user