fix spm decoder multi-byte (#1092)

parent 4394633ce0
commit 6fd1f70f73
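The underlying issue: SentencePiece byte-fallback pieces such as <0xE3> each carry a single UTF-8 byte, so a non-ASCII character is spread over several tokens. Mapping each byte piece to a one-character string with chr() (the pre-patch behavior in the diff below) can never reassemble those bytes into the original character. A minimal illustrative snippet, not part of the commit:

# Illustrative only (not part of the commit): "こ" (U+3053) is three UTF-8
# bytes, which an SPM tokenizer with byte fallback emits as the pieces
# <0xE3>, <0x81>, <0x93>.
byte_pieces = ["<0xE3>", "<0x81>", "<0x93>"]

# Pre-patch mapping: each byte piece becomes a one-character string via chr().
# chr(0xE3) is the code point U+00E3 ("ã"), not a UTF-8 byte, so the pieces
# never recombine into "こ".
old = "".join(chr(int(p[3:5], 16)) for p in byte_pieces)
print(repr(old))  # 'ã\x81\x93'

# Patched mapping: keep raw bytes and decode the buffer only when it is flushed.
new = b"".join(bytes([int(p[3:5], 16)]) for p in byte_pieces).decode("utf-8")
print(new)  # こ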
@@ -6,12 +6,6 @@ from transformers import AutoTokenizer
 REPLACEMENT_CHAR = "\ufffd"
 
 
-def _remove_space(x):
-    if x and x[0] == " ":
-        return x[1:]
-    return x
-
-
 class StreamingDetokenizer:
     """The streaming detokenizer interface so that we can detokenize one token at a time.
 
@@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
 
     def __init__(self, tokenizer, trim_space=True):
         self.trim_space = trim_space
+        self._sep = "\u2581".encode()
 
         # Extract the tokens in a list from id to text
         self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
         for value, tokenid in tokenizer.vocab.items():
-            self.tokenmap[tokenid] = value
-
-        # Replace bytes with their value
-        for i in range(len(self.tokenmap)):
-            if self.tokenmap[i].startswith("<0x"):
-                self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
+            if value.startswith("<0x"):
+                # Replace bytes with their value
+                self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
+            else:
+                self.tokenmap[tokenid] = value.encode()
 
         self.reset()
 
     def reset(self):
         self.offset = 0
-        self._unflushed = ""
+        self._unflushed = b""
         self.text = ""
         self.tokens = []
 
+    def _flush(self):
+        text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
+        if not self.text and self.trim_space and text and text[0] == " ":
+            text = text[1:]
+        self.text += text
+
     def add_token(self, token):
         v = self.tokenmap[token]
-        if v[0] == "\u2581":
-            if self.text or not self.trim_space:
-                self.text += self._unflushed.replace("\u2581", " ")
-            else:
-                self.text = _remove_space(self._unflushed.replace("\u2581", " "))
+        if v.startswith(self._sep):
+            self._flush()
             self._unflushed = v
         else:
             self._unflushed += v
 
     def finalize(self):
-        if self.text or not self.trim_space:
-            self.text += self._unflushed.replace("\u2581", " ")
-        else:
-            self.text = _remove_space(self._unflushed.replace("\u2581", " "))
-        self._unflushed = ""
+        self._flush()
+        self._unflushed = b""
 
 
 class BPEStreamingDetokenizer(StreamingDetokenizer):
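For reference, a self-contained, non-streaming sketch of the byte-accumulation strategy the patched detokenizer uses; piece_to_bytes and detokenize are illustrative helpers, not part of mlx-examples:

SEP = "\u2581".encode()  # SentencePiece word-boundary marker as UTF-8 bytes


def piece_to_bytes(piece):
    # Byte-fallback pieces like "<0xE3>" stand for one raw byte; everything
    # else is an ordinary piece encoded as UTF-8.
    if piece.startswith("<0x"):
        return bytes([int(piece[3:5], 16)])
    return piece.encode()


def detokenize(pieces, trim_space=True):
    text = ""
    unflushed = b""

    def flush():
        # Decode only when the buffer is flushed, so multi-byte characters
        # that were split across byte-fallback pieces come out intact.
        nonlocal text, unflushed
        chunk = unflushed.replace(SEP, b" ").decode("utf-8")
        if not text and trim_space and chunk.startswith(" "):
            chunk = chunk[1:]
        text += chunk
        unflushed = b""

    for piece in pieces:
        v = piece_to_bytes(piece)
        if v.startswith(SEP):
            flush()  # a new word starts; emit whatever is buffered
            unflushed = v
        else:
            unflushed += v
    flush()
    return text


# "▁Hello" followed by the three byte-fallback pieces that spell "こ" in UTF-8
print(detokenize(["\u2581Hello", "<0xE3>", "<0x81>", "<0x93>"]))  # Helloこ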
@@ -42,6 +42,9 @@ class TestTokenizers(unittest.TestCase):
             text += detokenizer.last_segment
             self.assertEqual(text, expected_text)
 
+        tokens = tokenizer.encode("こんにちは!私の名前はAI")
+        check(tokens)
+
         tokens = tokenizer.encode("a ,b")
         check(tokens)
 
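The body of the test's check helper is not shown in this hunk; a hedged sketch of how such a streaming check is typically written against the interface above (reset/add_token/finalize/last_segment), with check_streaming as a hypothetical stand-in:

def check_streaming(tokenizer, detokenizer, text):
    # Compare token-by-token streaming output against a one-shot decode.
    tokens = tokenizer.encode(text)
    expected_text = tokenizer.decode(tokens)

    detokenizer.reset()
    out = ""
    for tok in tokens:
        detokenizer.add_token(tok)
        out += detokenizer.last_segment  # text made available since the last call
    detokenizer.finalize()
    out += detokenizer.last_segment

    assert out == expected_text, (out, expected_text)


# e.g. check_streaming(hf_tokenizer, SPMStreamingDetokenizer(hf_tokenizer),
#                      "こんにちは!私の名前はAI")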