Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 17:31:18 +08:00

fix (#1079)

commit 0f799947d0 (parent e510987870)
@@ -186,6 +186,8 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
         # https://github.com/openai/gpt-2/blob/master/src/encoder.py
         self.make_byte_decoder()
 
+        self._added_ids = set(tokenizer.added_tokens_decoder.keys())
+
     def reset(self):
         self.offset = 0
         self._unflushed = ""
@@ -205,11 +207,16 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
 
     def add_token(self, token):
         v = self.tokenmap[token]
-        if self._byte_decoder[v[0]] == 32:
+        is_added = token in self._added_ids
+        if is_added or self._byte_decoder[v[0]] == 32:
             current_text = bytearray(
                 self._byte_decoder[c] for c in self._unflushed
             ).decode("utf-8")
             self.text += self._maybe_trim_space(current_text)
-            self._unflushed = v
+            if is_added:
+                self.text += v
+                self._unflushed = ""
+            else:
+                self._unflushed = v
         else:
             self._unflushed += v
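Why the new is_added branch matters: regular BPE vocabulary entries are
byte-level strings, so every character in self._unflushed has an entry in
the table that make_byte_decoder builds, but added/special tokens are stored
verbatim and can contain characters outside that 256-entry alphabet (e.g.
DeepSeek's "<｜end▁of▁sentence｜>"). A toy illustration, assuming a
hand-rolled three-entry decoder standing in for the real table:

# Hypothetical subset of the byte decoder; the real table maps all 256
# byte values via GPT-2's bytes_to_unicode scheme.
byte_decoder = {"Ġ": 32, "H": 72, "i": 105}

def flush(unflushed):
    # Same decode step as the current_text expression in the diff above.
    return bytearray(byte_decoder[c] for c in unflushed).decode("utf-8")

print(flush("ĠHi"))  # " Hi": a normal byte-level BPE string decodes cleanly
# flush("<｜end▁of▁sentence｜>") would raise KeyError: "｜" and "▁" are not
# byte-level characters, so added tokens must bypass the decoder and be
# appended to self.text verbatim.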
|
@@ -74,6 +74,17 @@ class TestTokenizers(unittest.TestCase):
             tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
             self.check_tokenizer(tokenizer)
 
+    def test_special_tokens(self):
+        tokenizer_repo = "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx"
+        tokenizer = self.download_tokenizer(tokenizer_repo)
+
+        detokenizer = tokenizer.detokenizer
+        detokenizer.reset()
+        detokenizer.add_token(tokenizer.eos_token_id)
+        detokenizer.finalize()
+
+        self.assertEqual(detokenizer.last_segment, tokenizer.eos_token)
+
 
 if __name__ == "__main__":
     unittest.main()
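For context, the interface exercised here (reset, add_token, finalize,
last_segment) is the StreamingDetokenizer API from the first file, normally
driven token by token during generation. A minimal sketch of that loop; the
stream_decode helper and token_ids are illustrative, not part of mlx_lm:

def stream_decode(tokenizer, token_ids):
    # tokenizer is an mlx_lm TokenizerWrapper; its detokenizer emits text
    # incrementally via last_segment as ids arrive.
    detok = tokenizer.detokenizer
    detok.reset()
    for tid in token_ids:
        detok.add_token(tid)
        print(detok.last_segment, end="", flush=True)
    detok.finalize()
    print(detok.last_segment, end="", flush=True)

With this fix, an added token such as tokenizer.eos_token_id comes through
last_segment as its literal text instead of hitting a byte-decoder KeyError.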
|