mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Tokenizer updates + tests (#1024)
* tokenizer updates + tests
* nit
* add can_trim_prompt_cache
* nits
This commit is contained in:
76
llms/tests/test_tokenizers.py
Normal file
76
llms/tests/test_tokenizers.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx_lm.tokenizer_utils import (
|
||||
BPEStreamingDetokenizer,
|
||||
NaiveStreamingDetokenizer,
|
||||
SPMStreamingDetokenizer,
|
||||
load_tokenizer,
|
||||
)
|
||||
|
||||
|
||||
class TestTokenizers(unittest.TestCase):
    """Check that streaming detokenizers agree with ``tokenizer.decode``."""

    def download_tokenizer(self, repo):
        """Fetch the tokenizer files of a Hub repository and load them.

        Only the tokenizer-related files listed in ``allow_patterns`` are
        requested, so model weights are not downloaded.

        Args:
            repo: Repository id passed to ``snapshot_download``.

        Returns:
            The object returned by ``load_tokenizer`` for the snapshot path.
        """
        path = Path(
            snapshot_download(
                repo_id=repo,
                allow_patterns=[
                    "tokenizer.json",
                    "tokenizer_config.json",
                    "special_tokens_map.json",
                    "tokenizer.model",
                ],
            )
        )
        return load_tokenizer(path)

    def check_tokenizer(self, tokenizer):
        """Assert that token-by-token detokenization matches ``decode``.

        Args:
            tokenizer: Object exposing ``encode``, ``decode`` and a
                ``detokenizer`` with ``reset``, ``add_token``, ``finalize``
                and ``last_segment``.
        """

        def check(tokens):
            expected_text = tokenizer.decode(tokens)
            detokenizer = tokenizer.detokenizer
            detokenizer.reset()
            segments = []
            for token in tokens:
                detokenizer.add_token(token)
                segments.append(detokenizer.last_segment)
            # Collect whatever segment is exposed after finalizing as well.
            detokenizer.finalize()
            segments.append(detokenizer.last_segment)
            self.assertEqual("".join(segments), expected_text)

        prompts = (
            "a ,b",
            '{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}',
            "3 3",
        )
        for prompt in prompts:
            check(tokenizer.encode(prompt))

    def test_tokenizers(self):
        """Verify the detokenizer class chosen per repo and its output."""
        tokenizer_repos = [
            ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
            ("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer),
            ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
            ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
            ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
        ]
        for tokenizer_repo, expected_detokenizer in tokenizer_repos:
            with self.subTest(tokenizer=tokenizer_repo):
                tokenizer = self.download_tokenizer(tokenizer_repo)
                # Result is discarded: this only checks that decoding a few
                # low token ids does not raise.
                tokenizer.decode([0, 1, 2])
                # assertIsInstance reports the actual type on failure.
                self.assertIsInstance(tokenizer.detokenizer, expected_detokenizer)
                self.check_tokenizer(tokenizer)

        # Try one with a naive detokenizer
        tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
        # NOTE(review): presumably `tokenizer.detokenizer` is backed by this
        # private attribute; confirm against mlx_lm.tokenizer_utils.
        tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
        self.check_tokenizer(tokenizer)
|
||||
|
||||
|
||||
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
Reference in New Issue
Block a user