2024-10-15 01:48:46 +08:00
|
|
|
# Copyright © 2024 Apple Inc.
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
from huggingface_hub import snapshot_download
|
|
|
|
from mlx_lm.tokenizer_utils import (
|
|
|
|
BPEStreamingDetokenizer,
|
|
|
|
NaiveStreamingDetokenizer,
|
|
|
|
SPMStreamingDetokenizer,
|
|
|
|
load_tokenizer,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
class TestTokenizers(unittest.TestCase):
    """Exercise the streaming detokenizers against real tokenizer files.

    Each test downloads only the tokenizer artifacts (never model weights)
    from the Hugging Face Hub, then verifies that streaming detokenization
    reproduces the output of a one-shot ``decode`` call.
    """

    def download_tokenizer(self, repo):
        """Download the tokenizer files for ``repo`` and load a tokenizer.

        Args:
            repo: Hugging Face Hub repo id, e.g. ``"org/model-name"``.

        Returns:
            The tokenizer produced by ``load_tokenizer`` on the snapshot path.
        """
        path = Path(
            snapshot_download(
                repo_id=repo,
                # Restrict the download to tokenizer artifacts only; the
                # model weights are not needed for these tests.
                allow_patterns=[
                    "tokenizer.json",
                    "tokenizer_config.json",
                    "special_tokens_map.json",
                    "tokenizer.model",
                ],
            )
        )
        return load_tokenizer(path)

    def check_tokenizer(self, tokenizer):
        """Assert streaming detokenization matches one-shot decoding."""

        def check(tokens):
            # Ground truth: decode the whole token sequence at once.
            expected_text = tokenizer.decode(tokens)
            detokenizer = tokenizer.detokenizer
            detokenizer.reset()
            # Stream tokens one at a time, accumulating emitted segments.
            text = ""
            for t in tokens:
                detokenizer.add_token(t)
                text += detokenizer.last_segment
            # finalize() flushes any text still held back by the streamer.
            detokenizer.finalize()
            text += detokenizer.last_segment
            self.assertEqual(text, expected_text)

        # Space preceding punctuation.
        check(tokenizer.encode("a ,b"))

        # JSON-like text mixing quotes, spaces, and a float.
        check(tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}'))

        # Repeated token separated by a space.
        check(tokenizer.encode("3 3"))

        # Code-like text with punctuation-heavy tokens.
        check(tokenizer.encode("import 'package:flutter/material.dart';"))

    def test_tokenizers(self):
        # Repos paired with the streaming detokenizer class that
        # load_tokenizer is expected to select for them.
        tokenizer_repos = [
            ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
            ("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer),
            ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
            ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
            ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
        ]
        for tokenizer_repo, expected_detokenizer in tokenizer_repos:
            with self.subTest(tokenizer=tokenizer_repo):
                tokenizer = self.download_tokenizer(tokenizer_repo)
                # Smoke-test plain decoding before the streaming checks.
                tokenizer.decode([0, 1, 2])
                # assertIsInstance gives an informative failure message,
                # unlike assertTrue(isinstance(...)).
                self.assertIsInstance(tokenizer.detokenizer, expected_detokenizer)
                self.check_tokenizer(tokenizer)

        # Try one with a naive detokenizer
        tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
        tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
        self.check_tokenizer(tokenizer)
|
|
|
|
|
|
|
|
|
|
|
|
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
|