# Copyright © 2024 Apple Inc.

import unittest
from pathlib import Path

from huggingface_hub import snapshot_download
from mlx_lm.tokenizer_utils import (
    BPEStreamingDetokenizer,
    NaiveStreamingDetokenizer,
    SPMStreamingDetokenizer,
    load_tokenizer,
)


class TestTokenizers(unittest.TestCase):
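    # These tests check that each streaming detokenizer reproduces the text
    # produced by a full decode when tokens are fed to it one at a time.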

    def download_tokenizer(self, repo):
        path = Path(
            snapshot_download(
                repo_id=repo,
                allow_patterns=[
                    "tokenizer.json",
                    "tokenizer_config.json",
                    "special_tokens_map.json",
                    "tokenizer.model",
                ],
            )
        )
        return load_tokenizer(path)

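    # Stream the tokens through the detokenizer one at a time and verify that
    # the accumulated segments match a full decode of the same token list.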
    def check_tokenizer(self, tokenizer):
        def check(tokens):
            expected_text = tokenizer.decode(tokens)
            detokenizer = tokenizer.detokenizer
            detokenizer.reset()
            text = ""
            for t in tokens:
                detokenizer.add_token(t)
                seg = detokenizer.last_segment
                text += seg
            detokenizer.finalize()
            text += detokenizer.last_segment
            self.assertEqual(text, expected_text)

        tokens = tokenizer.encode("a ,b")
        check(tokens)

        tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}')
        check(tokens)

        tokens = tokenizer.encode("3 3")
        check(tokens)

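    # Each repo should load a tokenizer whose detokenizer is the expected
    # streaming class, and every tokenizer must pass the round-trip check above.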
    def test_tokenizers(self):
        tokenizer_repos = [
            ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
            ("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer),
            ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
            ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
            ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
        ]
        for tokenizer_repo, expected_detokenizer in tokenizer_repos:
            with self.subTest(tokenizer=tokenizer_repo):
                tokenizer = self.download_tokenizer(tokenizer_repo)
                tokenizer.decode([0, 1, 2])
                self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer))
                self.check_tokenizer(tokenizer)

        # Try one with a naive detokenizer
        tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
        tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
        self.check_tokenizer(tokenizer)


if __name__ == "__main__":
    unittest.main()