diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 0cbc3b9b..568a672d 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -186,6 +186,8 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
         # https://github.com/openai/gpt-2/blob/master/src/encoder.py
         self.make_byte_decoder()
 
+        self._added_ids = set(tokenizer.added_tokens_decoder.keys())
+
     def reset(self):
         self.offset = 0
         self._unflushed = ""
@@ -205,12 +207,17 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
 
     def add_token(self, token):
         v = self.tokenmap[token]
-        if self._byte_decoder[v[0]] == 32:
+        is_added = token in self._added_ids
+        if is_added or self._byte_decoder[v[0]] == 32:
             current_text = bytearray(
                 self._byte_decoder[c] for c in self._unflushed
             ).decode("utf-8")
             self.text += self._maybe_trim_space(current_text)
-            self._unflushed = v
+            if is_added:
+                self.text += v
+                self._unflushed = ""
+            else:
+                self._unflushed = v
         else:
             self._unflushed += v
 
diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py
index 03445c1f..3c93fbe2 100644
--- a/llms/tests/test_tokenizers.py
+++ b/llms/tests/test_tokenizers.py
@@ -74,6 +74,17 @@ class TestTokenizers(unittest.TestCase):
         tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
         self.check_tokenizer(tokenizer)
 
+    def test_special_tokens(self):
+        tokenizer_repo = "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx"
+        tokenizer = self.download_tokenizer(tokenizer_repo)
+
+        detokenizer = tokenizer.detokenizer
+        detokenizer.reset()
+        detokenizer.add_token(tokenizer.eos_token_id)
+        detokenizer.finalize()
+
+        self.assertEqual(detokenizer.last_segment, tokenizer.eos_token)
+
 
 if __name__ == "__main__":
     unittest.main()