Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-30 02:53:41 +08:00)
Merge remote-tracking branch 'origin/completion_only' into completion_only
Commit e0d66f5479
@@ -42,7 +42,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
         if cache is not None and cache[0] is not None:
             c = cache[0]
             if hasattr(c, "max_size"):
-                offset = min(c.max_size - 1, c.offset)
+                offset = min(c.max_size, c.offset)
                 window_size = c.max_size
             else:
                 offset = c.offset

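For orientation, here is a minimal sketch of how the clamped offset combines with a window size to form a sliding-window causal mask. It is an illustration written for this note, assuming the common additive-mask convention (0 for allowed positions, -inf for blocked ones); it is not the library's create_causal_mask.

import mlx.core as mx

def sliding_window_causal_mask(N, offset=0, window_size=None):
    # Illustrative helper: queries sit at absolute positions [offset, offset + N),
    # keys cover [0, offset + N).
    q = mx.arange(offset, offset + N)[:, None]
    k = mx.arange(offset + N)[None, :]
    allowed = k <= q  # causal: no attending to future keys
    if window_size is not None:
        # keep only the most recent window_size keys for each query
        allowed = mx.logical_and(allowed, k > q - window_size)
    return mx.where(allowed, 0.0, float("-inf"))

# With a rotating cache of max_size = 8 that has already absorbed 20 tokens, the
# clamped offset is min(8, 20) = 8, so the mask width matches what the cache returns.
mask = sliding_window_causal_mask(N=4, offset=8, window_size=8)
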
@@ -325,9 +325,9 @@ class RotatingKVCache(_BaseCache):
             self.keys = self._temporal_order(self.keys)
             self.values = self._temporal_order(self.values)

-            # The largest size is self.max_size + S - 1 to ensure
+            # The largest size is self.max_size + S to ensure
             # every token gets at least self.max_size context
-            trim_size = self._idx - self.max_size + 1
+            trim_size = self._idx - self.max_size
             self.keys = self._trim(trim_size, self.keys, keys)
             self.values = self._trim(trim_size, self.values, values)
         self.offset += keys.shape[2]

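A small arithmetic check of the adjusted trim rule, in plain Python with hypothetical numbers; the variable names mirror the diff, but nothing else from RotatingKVCache is reproduced here.

# A rotating cache holding max_size = 8 keys (_idx = 8) receives a chunk of S = 3 new keys.
max_size, _idx, S = 8, 8, 3

trim_size = _idx - max_size        # 0 under the new rule (the old "+ 1" gave 1)
kept = _idx - max(trim_size, 0)    # 8 keys retained from the cache
total = kept + S                   # 11 == max_size + S, matching the updated comment
assert total == max_size + S
# The first token of the new chunk can still attend to max_size = 8 earlier keys
# (later tokens see even more); the old rule trimmed one key too many and left it with 7.
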
@@ -6,12 +6,6 @@ from transformers import AutoTokenizer
 REPLACEMENT_CHAR = "\ufffd"


-def _remove_space(x):
-    if x and x[0] == " ":
-        return x[1:]
-    return x
-
-
 class StreamingDetokenizer:
     """The streaming detokenizer interface so that we can detokenize one token at a time.

@@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):

     def __init__(self, tokenizer, trim_space=True):
         self.trim_space = trim_space
+        self._sep = "\u2581".encode()

         # Extract the tokens in a list from id to text
         self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
         for value, tokenid in tokenizer.vocab.items():
-            self.tokenmap[tokenid] = value
-
-        # Replace bytes with their value
-        for i in range(len(self.tokenmap)):
-            if self.tokenmap[i].startswith("<0x"):
-                self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
+            if value.startswith("<0x"):
+                # Replace bytes with their value
+                self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
+            else:
+                self.tokenmap[tokenid] = value.encode()

         self.reset()

     def reset(self):
         self.offset = 0
-        self._unflushed = ""
+        self._unflushed = b""
         self.text = ""
         self.tokens = []

+    def _flush(self):
+        text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
+        if not self.text and self.trim_space and text and text[0] == " ":
+            text = text[1:]
+        self.text += text
+
     def add_token(self, token):
         v = self.tokenmap[token]
-        if v[0] == "\u2581":
-            if self.text or not self.trim_space:
-                self.text += self._unflushed.replace("\u2581", " ")
-            else:
-                self.text = _remove_space(self._unflushed.replace("\u2581", " "))
+        if v.startswith(self._sep):
+            self._flush()
             self._unflushed = v
         else:
             self._unflushed += v

     def finalize(self):
-        if self.text or not self.trim_space:
-            self.text += self._unflushed.replace("\u2581", " ")
-        else:
-            self.text = _remove_space(self._unflushed.replace("\u2581", " "))
-        self._unflushed = ""
+        self._flush()
+        self._unflushed = b""


 class BPEStreamingDetokenizer(StreamingDetokenizer):

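To see why the detokenizer now accumulates raw bytes and decodes only when it flushes, here is a standalone sketch using made-up token strings rather than the class from the diff: a character missing from an SPM vocabulary arrives as <0xNN> byte pieces that are only valid UTF-8 once reassembled.

# Hypothetical pieces for "▁café" when "é" is absent from the vocabulary:
# the character arrives as two byte tokens, <0xC3> <0xA9> (UTF-8 for "é").
pieces = ["\u2581caf", "<0xC3>", "<0xA9>"]

buf = b""
for p in pieces:
    if p.startswith("<0x"):
        buf += bytes([int(p[3:5], 16)])  # one raw byte; not decodable on its own
    else:
        buf += p.encode()

text = buf.replace("\u2581".encode(), b" ").decode("utf-8")
print(text)  # " café"; decoding each byte separately with chr() would have garbled it
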
@@ -42,6 +42,9 @@ class TestTokenizers(unittest.TestCase):
             text += detokenizer.last_segment
             self.assertEqual(text, expected_text)

+        tokens = tokenizer.encode("こんにちは!私の名前はAI")
+        check(tokens)
+
         tokens = tokenizer.encode("a ,b")
         check(tokens)

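Finally, a hedged sketch of driving the streaming detokenizer by hand on the new test string, mirroring the check() helper visible in the test context. The import path and the model id are assumptions for illustration, not taken from the diff.

from transformers import AutoTokenizer
from mlx_lm.tokenizer_utils import SPMStreamingDetokenizer  # import path assumed

# Any SentencePiece-based tokenizer should do; this repo id is only an example.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
detokenizer = SPMStreamingDetokenizer(tokenizer)

tokens = tokenizer.encode("こんにちは!私の名前はAI")
text = ""
for t in tokens:
    detokenizer.add_token(t)
    text += detokenizer.last_segment  # segments are emitted only once they decode cleanly
detokenizer.finalize()
text += detokenizer.last_segment

print(text)  # expected to match tokenizer.decode(tokens)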