From 6fd1f70f7366a1e55f14e2b4cd885b86875ab56c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 5 Nov 2024 06:06:26 -0800 Subject: [PATCH 1/2] fix spm decoder multi-byte (#1092) --- llms/mlx_lm/tokenizer_utils.py | 40 +++++++++++++++------------------- llms/tests/test_tokenizers.py | 3 +++ 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 568a672d..9d390733 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -6,12 +6,6 @@ from transformers import AutoTokenizer REPLACEMENT_CHAR = "\ufffd" -def _remove_space(x): - if x and x[0] == " ": - return x[1:] - return x - - class StreamingDetokenizer: """The streaming detokenizer interface so that we can detokenize one token at a time. @@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer): def __init__(self, tokenizer, trim_space=True): self.trim_space = trim_space + self._sep = "\u2581".encode() # Extract the tokens in a list from id to text self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1) for value, tokenid in tokenizer.vocab.items(): - self.tokenmap[tokenid] = value - - # Replace bytes with their value - for i in range(len(self.tokenmap)): - if self.tokenmap[i].startswith("<0x"): - self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16)) + if value.startswith("<0x"): + # Replace bytes with their value + self.tokenmap[tokenid] = bytes([int(value[3:5], 16)]) + else: + self.tokenmap[tokenid] = value.encode() self.reset() def reset(self): self.offset = 0 - self._unflushed = "" + self._unflushed = b"" self.text = "" self.tokens = [] + def _flush(self): + text = self._unflushed.replace(self._sep, b" ").decode("utf-8") + if not self.text and self.trim_space and text and text[0] == " ": + text = text[1:] + self.text += text + def add_token(self, token): v = self.tokenmap[token] - if v[0] == "\u2581": - if self.text or not self.trim_space: - self.text += self._unflushed.replace("\u2581", " ") - else: - self.text = _remove_space(self._unflushed.replace("\u2581", " ")) + if v.startswith(self._sep): + self._flush() self._unflushed = v else: self._unflushed += v def finalize(self): - if self.text or not self.trim_space: - self.text += self._unflushed.replace("\u2581", " ") - else: - self.text = _remove_space(self._unflushed.replace("\u2581", " ")) - self._unflushed = "" + self._flush() + self._unflushed = b"" class BPEStreamingDetokenizer(StreamingDetokenizer): diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py index 3c93fbe2..9c30d51e 100644 --- a/llms/tests/test_tokenizers.py +++ b/llms/tests/test_tokenizers.py @@ -42,6 +42,9 @@ class TestTokenizers(unittest.TestCase): text += detokenizer.last_segment self.assertEqual(text, expected_text) + tokens = tokenizer.encode("こんにちは!私の名前はAI") + check(tokens) + tokens = tokenizer.encode("a ,b") check(tokens) From ed9e81dd581a9505e677e12c025137d5326fe6df Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 5 Nov 2024 10:24:24 -0800 Subject: [PATCH 2/2] Fix rotating kv cache size (#1093) --- llms/mlx_lm/models/base.py | 2 +- llms/mlx_lm/models/cache.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index cda41c79..f02f49b1 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -42,7 +42,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): if cache is not None and cache[0] is not None: c = cache[0] if hasattr(c, "max_size"): - offset = min(c.max_size - 1, c.offset) + offset = min(c.max_size, c.offset) window_size = c.max_size else: offset = c.offset diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index 1cd5289d..14026f0c 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -325,9 +325,9 @@ class RotatingKVCache(_BaseCache): self.keys = self._temporal_order(self.keys) self.values = self._temporal_order(self.values) - # The largest size is self.max_size + S - 1 to ensure + # The largest size is self.max_size + S to ensure # every token gets at least self.max_size context - trim_size = self._idx - self.max_size + 1 + trim_size = self._idx - self.max_size self.keys = self._trim(trim_size, self.keys, keys) self.values = self._trim(trim_size, self.values, values) self.offset += keys.shape[2]