batched min p and fix spec gen sampling (#1222)

This commit is contained in:
Awni Hannun
2025-01-27 15:40:31 -08:00
committed by GitHub
parent 77faa14ba4
commit f44a52e2dc
3 changed files with 24 additions and 13 deletions

View File

@@ -28,6 +28,12 @@ class TestSampleUtils(unittest.TestCase):
token = top_p_sampling(logits, 0.95, temperature).item()
self.assertTrue(token in (1, 2, 3))
# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)
tokens = top_p_sampling(logits, 0.5, temperature)
self.assertEqual(tokens.tolist(), [0, 1])
def test_min_p_sampling(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
@@ -42,6 +48,12 @@ class TestSampleUtils(unittest.TestCase):
token = min_p_sampling(logits, 0.05)
self.assertTrue(token in (0, 3))
# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)
tokens = min_p_sampling(logits, 0.7)
self.assertEqual(tokens.tolist(), [0, 1])
def test_top_k_sampling(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)