fix + test

This commit is contained in:
Awni Hannun 2024-12-17 09:09:17 -08:00
parent 8c67480050
commit 34c9cfcc3c
2 changed files with 25 additions and 14 deletions

View File

@ -174,14 +174,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
# https://github.com/openai/gpt-2/blob/master/src/encoder.py
self.make_byte_decoder()
self._added_ids = set(tokenizer.added_tokens_decoder.keys())
def reset(self):
self.offset = 0
self._unflushed = ""
self.text = ""
self.tokens = []
def _decode_bytes(self, seq):
barr = bytearray()
for c in seq:
res = self._byte_decoder.get(c, False)
if res:
barr.append(res)
else:
barr.extend(bytes(c, "utf-8"))
return barr.decode("utf-8", "replace")
def _maybe_trim_space(self, current_text):
if len(current_text) == 0:
return current_text
@ -193,21 +201,20 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
return current_text[1:]
return current_text
def _is_single_space(self, v):
return
def add_token(self, token):
self.tokens.append(token)
v = self.tokenmap[token]
is_added = token in self._added_ids
if not is_added:
self._unflushed += v
text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
"utf-8", "replace"
)
if is_added:
# We need to manually encode and decode the added tokens in case special characters
# used for `\n` / `\t` have been manually added in the added tokens
v = self.tokenizer.decode(self.tokenizer.encode(v))
text += v
if not text.endswith("\ufffd"):
self._unflushed += v
text = self._decode_bytes(self._unflushed)
# For multi-byte utf-8 wait until they are complete
# For single spaces wait until the next token to clean it if needed
if not text.endswith("\ufffd") and not (
len(v) == 1 and self._byte_decoder[v[0]] == 32
):
self.text += self._maybe_trim_space(text)
self._unflushed = ""

View File

@ -58,6 +58,9 @@ class TestTokenizers(unittest.TestCase):
tokens = tokenizer.encode("import 'package:flutter/material.dart';")
check(tokens)
tokens = tokenizer.encode("hello\nworld")
check(tokens)
def test_tokenizers(self):
tokenizer_repos = [
("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
@ -65,6 +68,7 @@ class TestTokenizers(unittest.TestCase):
("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
("mlx-community/Falcon3-7B-Instruct-4bit", BPEStreamingDetokenizer),
]
for tokenizer_repo, expected_detokenizer in tokenizer_repos:
with self.subTest(tokenizer=tokenizer_repo):