mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Fix decoding manually added tokens (#1164)
* Fix decoding manually added tokens
* fix + test
* nit
* nit
* no lag bpe

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
dfa4dd6c93
commit
845efddc8c
@ -127,23 +127,23 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
|
|||||||
self.text = ""
|
self.text = ""
|
||||||
self.tokens = []
|
self.tokens = []
|
||||||
|
|
||||||
def _flush(self):
|
def _try_flush(self, force=False):
|
||||||
text = self._unflushed.replace(self._sep, b" ").decode("utf-8", "replace")
|
text = self._unflushed.replace(self._sep, b" ").decode("utf-8", "replace")
|
||||||
|
if not force and text.endswith("\ufffd"):
|
||||||
|
return
|
||||||
if not self.text and self.trim_space and text and text[0] == " ":
|
if not self.text and self.trim_space and text and text[0] == " ":
|
||||||
text = text[1:]
|
text = text[1:]
|
||||||
self.text += text
|
self.text += text
|
||||||
|
self._unflushed = b""
|
||||||
|
|
||||||
def add_token(self, token):
|
def add_token(self, token):
|
||||||
self.tokens.append(token)
|
self.tokens.append(token)
|
||||||
v = self.tokenmap[token]
|
v = self.tokenmap[token]
|
||||||
if v.startswith(self._sep):
|
self._unflushed += v
|
||||||
self._flush()
|
self._try_flush()
|
||||||
self._unflushed = v
|
|
||||||
else:
|
|
||||||
self._unflushed += v
|
|
||||||
|
|
||||||
def finalize(self):
|
def finalize(self):
|
||||||
self._flush()
|
self._try_flush(force=True)
|
||||||
self._unflushed = b""
|
self._unflushed = b""
|
||||||
|
|
||||||
|
|
||||||
@ -158,7 +158,6 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
|||||||
_space_matches = (".", "?", "!", ",", "n't", "'m", "'s", "'ve", "'re")
|
_space_matches = (".", "?", "!", ",", "n't", "'m", "'s", "'ve", "'re")
|
||||||
|
|
||||||
def __init__(self, tokenizer):
|
def __init__(self, tokenizer):
|
||||||
|
|
||||||
self.clean_spaces = tokenizer.clean_up_tokenization_spaces
|
self.clean_spaces = tokenizer.clean_up_tokenization_spaces
|
||||||
|
|
||||||
# Extract the tokens in a list from id to text
|
# Extract the tokens in a list from id to text
|
||||||
@ -172,14 +171,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
|||||||
# https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
# https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
||||||
self.make_byte_decoder()
|
self.make_byte_decoder()
|
||||||
|
|
||||||
self._added_ids = set(tokenizer.added_tokens_decoder.keys())
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.offset = 0
|
self.offset = 0
|
||||||
self._unflushed = ""
|
self._unflushed = ""
|
||||||
self.text = ""
|
self.text = ""
|
||||||
self.tokens = []
|
self.tokens = []
|
||||||
|
|
||||||
|
def _decode_bytes(self, seq):
|
||||||
|
barr = bytearray()
|
||||||
|
for c in seq:
|
||||||
|
res = self._byte_decoder.get(c, False)
|
||||||
|
if res:
|
||||||
|
barr.append(res)
|
||||||
|
else:
|
||||||
|
barr.extend(bytes(c, "utf-8"))
|
||||||
|
return barr.decode("utf-8", "replace")
|
||||||
|
|
||||||
def _maybe_trim_space(self, current_text):
|
def _maybe_trim_space(self, current_text):
|
||||||
if len(current_text) == 0:
|
if len(current_text) == 0:
|
||||||
return current_text
|
return current_text
|
||||||
@ -194,15 +201,14 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
|||||||
def add_token(self, token):
|
def add_token(self, token):
|
||||||
self.tokens.append(token)
|
self.tokens.append(token)
|
||||||
v = self.tokenmap[token]
|
v = self.tokenmap[token]
|
||||||
is_added = token in self._added_ids
|
self._unflushed += v
|
||||||
if not is_added:
|
text = self._decode_bytes(self._unflushed)
|
||||||
self._unflushed += v
|
|
||||||
text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
|
# For multi-byte utf-8 wait until they are complete
|
||||||
"utf-8", "replace"
|
# For single spaces wait until the next token to clean it if needed
|
||||||
)
|
if not text.endswith("\ufffd") and not (
|
||||||
if is_added:
|
len(v) == 1 and self._byte_decoder[v[0]] == 32
|
||||||
text += v
|
):
|
||||||
if not text.endswith("\ufffd"):
|
|
||||||
self.text += self._maybe_trim_space(text)
|
self.text += self._maybe_trim_space(text)
|
||||||
self._unflushed = ""
|
self._unflushed = ""
|
||||||
|
|
||||||
|
@ -58,6 +58,9 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
tokens = tokenizer.encode("import 'package:flutter/material.dart';")
|
tokens = tokenizer.encode("import 'package:flutter/material.dart';")
|
||||||
check(tokens)
|
check(tokens)
|
||||||
|
|
||||||
|
tokens = tokenizer.encode("hello\nworld")
|
||||||
|
check(tokens)
|
||||||
|
|
||||||
def test_tokenizers(self):
|
def test_tokenizers(self):
|
||||||
tokenizer_repos = [
|
tokenizer_repos = [
|
||||||
("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
|
("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
|
||||||
@ -65,6 +68,7 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
|
("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
|
||||||
("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
|
("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
|
||||||
("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
|
("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
|
||||||
|
("mlx-community/Falcon3-7B-Instruct-4bit", BPEStreamingDetokenizer),
|
||||||
]
|
]
|
||||||
for tokenizer_repo, expected_detokenizer in tokenizer_repos:
|
for tokenizer_repo, expected_detokenizer in tokenizer_repos:
|
||||||
with self.subTest(tokenizer=tokenizer_repo):
|
with self.subTest(tokenizer=tokenizer_repo):
|
||||||
|
Loading…
Reference in New Issue
Block a user