Generation refactor: part 2 (#1099)

* unify with stream_generate

* fixes

* nit

* some cleanup, warnings, tests

* fix test + faster min p + test

* version
This commit is contained in:
Awni Hannun
2024-11-23 11:47:06 -08:00
committed by GitHub
parent 004eb4cc9d
commit 0f135396ae
13 changed files with 184 additions and 197 deletions

View File

@@ -2,6 +2,7 @@
import unittest
from mlx_lm.sample_utils import make_logits_processors
from mlx_lm.utils import generate, load
@@ -25,8 +26,8 @@ class TestGenerate(unittest.TestCase):
self.tokenizer,
"hello",
max_tokens=5,
logits_processors=make_logits_processors(logit_bias),
verbose=False,
logit_bias=logit_bias,
)
self.assertEqual(text, "!!!!!")

View File

@@ -1,10 +1,10 @@
import unittest
import mlx.core as mx
from mlx_lm.sample_utils import top_p_sampling
from mlx_lm.sample_utils import min_p_sampling, top_p_sampling
class TestSamplingUtils(unittest.TestCase):
class TestSampleUtils(unittest.TestCase):
def test_top_p_sampling(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
@@ -28,6 +28,20 @@ class TestSamplingUtils(unittest.TestCase):
token = top_p_sampling(logits, 0.95, temperature).item()
self.assertTrue(token in (1, 2, 3))
def test_min_p_sampling(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
temperature = 1.0
token = min_p_sampling(logits, 0.8)
self.assertEqual(token, 0)
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
temperature = 1.0
for _ in range(5):
token = min_p_sampling(logits, 0.05)
self.assertTrue(token in (0, 3))
if __name__ == "__main__":
unittest.main()

View File

@@ -34,10 +34,11 @@ class TestTokenizers(unittest.TestCase):
detokenizer = tokenizer.detokenizer
detokenizer.reset()
text = ""
for t in tokens:
for e, t in enumerate(tokens):
detokenizer.add_token(t)
seg = detokenizer.last_segment
text += seg
self.assertEqual(detokenizer.tokens, tokens[: e + 1])
detokenizer.finalize()
text += detokenizer.last_segment
self.assertEqual(text, expected_text)