mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-31 10:58:07 +08:00 
			
		
		
		
	fix(mlx-lm): sorted probs in top_p implementation. (#610)
* fix(mlx-lm): the top p imp * chore: address comment
This commit is contained in:
		| @@ -22,7 +22,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr | ||||
|  | ||||
|     # sort probs in ascending order | ||||
|     sorted_indices = mx.argsort(probs, axis=-1) | ||||
|     sorted_probs = probs[..., sorted_indices] | ||||
|     sorted_probs = probs[..., sorted_indices.squeeze(0)] | ||||
|  | ||||
|     cumulative_probs = mx.cumsum(sorted_probs, axis=-1) | ||||
|  | ||||
|   | ||||
| @@ -18,6 +18,7 @@ class TestLora(unittest.TestCase): | ||||
|         expected_top_probs = mx.array([[0.0, 0.0, 0.0, 0.643914]]) | ||||
|         self.assertTrue(mx.allclose(token, expected_token)) | ||||
|         args, _ = mock_categorical.call_args | ||||
|         self.assertTrue(args[0].shape == expected_top_probs.shape) | ||||
|         self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs))) | ||||
|  | ||||
|         logits = mx.array([[1.0, 2.0, 3.0, 4.0]]) | ||||
| @@ -30,6 +31,7 @@ class TestLora(unittest.TestCase): | ||||
|         expected_top_probs = mx.array([[0.0, 0.0871443, 0.236883, 0.643914]]) | ||||
|         self.assertTrue(mx.allclose(token, expected_token)) | ||||
|         args, _ = mock_categorical.call_args | ||||
|         self.assertTrue(args[0].shape == expected_top_probs.shape) | ||||
|         self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs))) | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Anchen
					Anchen