fix spm decoder multi-byte (#1092)

Awni Hannun 2024-11-05 06:06:26 -08:00 committed by GitHub
parent 4394633ce0
commit 6fd1f70f73
2 changed files with 20 additions and 23 deletions


@@ -6,12 +6,6 @@ from transformers import AutoTokenizer
 
 REPLACEMENT_CHAR = "\ufffd"
 
-
-def _remove_space(x):
-    if x and x[0] == " ":
-        return x[1:]
-    return x
-
 
 class StreamingDetokenizer:
     """The streaming detokenizer interface so that we can detokenize one token at a time.
@@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
 
     def __init__(self, tokenizer, trim_space=True):
         self.trim_space = trim_space
+        self._sep = "\u2581".encode()
 
         # Extract the tokens in a list from id to text
         self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
         for value, tokenid in tokenizer.vocab.items():
-            self.tokenmap[tokenid] = value
-
-        # Replace bytes with their value
-        for i in range(len(self.tokenmap)):
-            if self.tokenmap[i].startswith("<0x"):
-                self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
+            if value.startswith("<0x"):
+                # Replace bytes with their value
+                self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
+            else:
+                self.tokenmap[tokenid] = value.encode()
 
         self.reset()
 
     def reset(self):
         self.offset = 0
-        self._unflushed = ""
+        self._unflushed = b""
         self.text = ""
         self.tokens = []
 
+    def _flush(self):
+        text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
+        if not self.text and self.trim_space and text and text[0] == " ":
+            text = text[1:]
+        self.text += text
+
     def add_token(self, token):
         v = self.tokenmap[token]
-        if v[0] == "\u2581":
-            if self.text or not self.trim_space:
-                self.text += self._unflushed.replace("\u2581", " ")
-            else:
-                self.text = _remove_space(self._unflushed.replace("\u2581", " "))
+        if v.startswith(self._sep):
+            self._flush()
             self._unflushed = v
         else:
             self._unflushed += v
 
     def finalize(self):
-        if self.text or not self.trim_space:
-            self.text += self._unflushed.replace("\u2581", " ")
-        else:
-            self.text = _remove_space(self._unflushed.replace("\u2581", " "))
-        self._unflushed = ""
+        self._flush()
+        self._unflushed = b""
 
 
 class BPEStreamingDetokenizer(StreamingDetokenizer):
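
For context, here is a standalone illustration (not part of the commit) of why the detokenizer now buffers raw bytes instead of characters. SentencePiece emits byte-fallback tokens such as <0xE3> for text outside its vocabulary, one token per UTF-8 byte. Mapping each token through chr() and concatenating strings, as the old code did, splits multi-byte characters into mojibake; joining the raw bytes and decoding the sequence once recovers the original text.

# Standalone illustration: SentencePiece byte-fallback tokens carry one
# UTF-8 byte each, so "こ" (UTF-8 bytes E3 81 93) arrives as three tokens.
byte_tokens = ["<0xE3>", "<0x81>", "<0x93>"]

# Old approach: chr() per byte, concatenated as str, which yields mojibake.
old = "".join(chr(int(t[3:5], 16)) for t in byte_tokens)
print(old)  # 'ã\x81\x93'

# New approach: accumulate bytes, decode the full UTF-8 sequence once.
new = b"".join(bytes([int(t[3:5], 16)]) for t in byte_tokens).decode("utf-8")
print(new)  # 'こ'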


@@ -42,6 +42,9 @@ class TestTokenizers(unittest.TestCase):
             text += detokenizer.last_segment
             self.assertEqual(text, expected_text)
 
+        tokens = tokenizer.encode("こんにちは私の名前はAI")
+        check(tokens)
+
         tokens = tokenizer.encode("a ,b")
         check(tokens)
 
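
The new test case streams Japanese text, which exercises the byte-fallback path above. The check helper itself lies outside this hunk; the sketch below is a hypothetical reconstruction of how such a helper presumably drives the streaming detokenizer, with the function name and signature treated as assumptions.

# Hypothetical reconstruction of a check()-style helper (the real body is
# not shown in this hunk): feed token ids one at a time and compare the
# streamed segments against the tokenizer's regular decode().
def check_streaming(tokenizer, detokenizer, tokens):
    expected_text = tokenizer.decode(tokens)
    detokenizer.reset()
    text = ""
    for t in tokens:
        detokenizer.add_token(t)
        text += detokenizer.last_segment
    detokenizer.finalize()
    text += detokenizer.last_segment
    assert text == expected_text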