From 6fd1f70f7366a1e55f14e2b4cd885b86875ab56c Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 5 Nov 2024 06:06:26 -0800
Subject: [PATCH] fix spm decoder multi-byte (#1092)

---
 llms/mlx_lm/tokenizer_utils.py | 40 +++++++++++++++++-----------------------
 llms/tests/test_tokenizers.py  |  3 +++
 2 files changed, 20 insertions(+), 23 deletions(-)

diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 568a672d..9d390733 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -6,12 +6,6 @@ from transformers import AutoTokenizer
 
 REPLACEMENT_CHAR = "\ufffd"
 
-def _remove_space(x):
-    if x and x[0] == " ":
-        return x[1:]
-    return x
-
-
 class StreamingDetokenizer:
     """The streaming detokenizer interface so that we can detokenize one token at a time.
 
@@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
 
     def __init__(self, tokenizer, trim_space=True):
         self.trim_space = trim_space
+        self._sep = "\u2581".encode()
 
         # Extract the tokens in a list from id to text
         self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
         for value, tokenid in tokenizer.vocab.items():
-            self.tokenmap[tokenid] = value
-
-        # Replace bytes with their value
-        for i in range(len(self.tokenmap)):
-            if self.tokenmap[i].startswith("<0x"):
-                self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
+            if value.startswith("<0x"):
+                # Replace bytes with their value
+                self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
+            else:
+                self.tokenmap[tokenid] = value.encode()
 
         self.reset()
 
     def reset(self):
         self.offset = 0
-        self._unflushed = ""
+        self._unflushed = b""
         self.text = ""
         self.tokens = []
 
+    def _flush(self):
+        text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
+        if not self.text and self.trim_space and text and text[0] == " ":
+            text = text[1:]
+        self.text += text
+
     def add_token(self, token):
         v = self.tokenmap[token]
-        if v[0] == "\u2581":
-            if self.text or not self.trim_space:
-                self.text += self._unflushed.replace("\u2581", " ")
-            else:
-                self.text = _remove_space(self._unflushed.replace("\u2581", " "))
+        if v.startswith(self._sep):
+            self._flush()
             self._unflushed = v
         else:
             self._unflushed += v
 
     def finalize(self):
-        if self.text or not self.trim_space:
-            self.text += self._unflushed.replace("\u2581", " ")
-        else:
-            self.text = _remove_space(self._unflushed.replace("\u2581", " "))
-        self._unflushed = ""
+        self._flush()
+        self._unflushed = b""
 
 
 class BPEStreamingDetokenizer(StreamingDetokenizer):
diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py
index 3c93fbe2..9c30d51e 100644
--- a/llms/tests/test_tokenizers.py
+++ b/llms/tests/test_tokenizers.py
@@ -42,6 +42,9 @@ class TestTokenizers(unittest.TestCase):
             text += detokenizer.last_segment
         self.assertEqual(text, expected_text)
 
+        tokens = tokenizer.encode("こんにちは！私の名前はAI")
+        check(tokens)
+
         tokens = tokenizer.encode("a ,b")
         check(tokens)
 
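
Note: the snippet below is a minimal, self-contained sketch of the behavior the new test exercises, handy for poking at the fix from a REPL. It builds the detokenizer directly instead of going through the mlx_lm loading path, and the checkpoint name is only a placeholder for any SentencePiece tokenizer with "<0x..>" byte-fallback tokens (downloading it needs Hugging Face Hub access).

    from transformers import AutoTokenizer

    from mlx_lm.tokenizer_utils import SPMStreamingDetokenizer

    # Placeholder model id; any SPM tokenizer with "<0x..>" byte tokens works.
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    detokenizer = SPMStreamingDetokenizer(tokenizer)

    # Multi-byte UTF-8 text: one character may span several byte tokens, so the
    # detokenizer must accumulate raw bytes and decode them only when flushing.
    tokens = tokenizer.encode("こんにちは！私の名前はAI", add_special_tokens=False)

    detokenizer.reset()
    text = ""
    for t in tokens:
        detokenizer.add_token(t)
        text += detokenizer.last_segment
    detokenizer.finalize()
    text += detokenizer.last_segment

    # With bytes buffered in _unflushed, the streamed text matches a plain decode;
    # the previous str-based buffering garbled this text.
    assert text == tokenizer.decode(tokens)

Before the patch, each "<0x..>" token was mapped to chr(byte) and concatenated as str, so multi-byte UTF-8 characters came out garbled during streaming; mapping token ids to bytes and decoding only in _flush keeps partial UTF-8 sequences intact until they are complete.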