From 8dca1a2f6091f443ffb54b5f39a390f6971e677c Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 14 Oct 2024 10:48:46 -0700
Subject: [PATCH] Tokenizer updates + tests (#1024)

* tokenizer updates + tests

* nit

* add can_trim_prompt_cache

* nits
---
 llms/mlx_lm/models/cache.py       |  9 +++-
 llms/mlx_lm/models/deepseek_v2.py |  6 +--
 llms/mlx_lm/tokenizer_utils.py    | 40 ++++++++--------
 llms/tests/test_tokenizers.py     | 76 +++++++++++++++++++++++++++++++
 4 files changed, 108 insertions(+), 23 deletions(-)
 create mode 100644 llms/tests/test_tokenizers.py

diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index b06422e5..a6a56e0a 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
     return cache
 
 
+def can_trim_prompt_cache(cache: List[Any]) -> bool:
+    """
+    Check if the model's cache can be trimmed.
+    """
+    return all(c.is_trimmable() for c in cache)
+
+
 def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
     """
     Trim the model's cache by the given number of tokens.
@@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
     Returns:
         (int): The number of tokens that were trimmed.
     """
-    if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
+    if not can_trim_prompt_cache(cache) or len(cache) == 0:
         return 0
     return [c.trim(num_tokens) for c in cache][0]
 
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
index 17d061a8..bb3e5184 100644
--- a/llms/mlx_lm/models/deepseek_v2.py
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -220,17 +220,17 @@ class DeepseekV2Attention(nn.Module):
 
         k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
 
-        k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
-
         if cache is not None:
             q_pe = self.rope(q_pe, cache.offset)
             k_pe = self.rope(k_pe, cache.offset)
+            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
             keys, values = cache.update_and_fetch(
                 mx.concatenate([k_nope, k_pe], axis=-1), values
             )
         else:
             q_pe = self.rope(q_pe)
             k_pe = self.rope(k_pe)
+            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
             keys = mx.concatenate([k_nope, k_pe], axis=-1)
 
         queries = mx.concatenate([q_nope, q_pe], axis=-1)
@@ -291,7 +291,7 @@ class MoEGate(nn.Module):
         scores = scores.reshape(bsz, seq_len, -1)
 
         k = self.top_k
-        inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
+        inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
         scores = mx.take_along_axis(scores, inds, axis=-1)
         scores = scores * self.routed_scaling_factor
 
diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 04bbbcc5..d8694d86 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -97,6 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
     def text(self):
         if self._current_tokens:
             self._current_text = self._tokenizer.decode(self._current_tokens)
+            if (
+                self._tokenizer.clean_up_tokenization_spaces
+                and self._current_text[-1] == " "
+            ):
+                self._current_text = self._current_text[:-1]
         if self._current_text and self._current_text[-1] == "\n":
             self._tokens.extend(self._current_tokens)
             self._text += self._current_text
@@ -164,9 +169,11 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
     """
 
     _byte_decoder = None
+    _space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re")
 
-    def __init__(self, tokenizer, trim_space=False):
-        self.trim_space = trim_space
+    def __init__(self, tokenizer):
+
+        self.clean_spaces = tokenizer.clean_up_tokenization_spaces
 
         # Extract the tokens in a list from id to text
         self.tokenmap = [None] * len(tokenizer.vocab)
@@ -185,17 +192,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
         self.text = ""
         self.tokens = []
 
+    def _maybe_trim_space(self, current_text):
+        if current_text[0] != " ":
+            return current_text
+        elif not self.text:
+            return current_text[1:]
+        elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
+            return current_text[1:]
+        return current_text
+
     def add_token(self, token):
         v = self.tokenmap[token]
-        # if the token starts with space
         if self._byte_decoder[v[0]] == 32:
             current_text = bytearray(
                 self._byte_decoder[c] for c in self._unflushed
             ).decode("utf-8")
-            if self.text or not self.trim_space:
-                self.text += current_text
-            else:
-                self.text += _remove_space(current_text)
+            self.text += self._maybe_trim_space(current_text)
             self._unflushed = v
         else:
             self._unflushed += v
@@ -204,10 +216,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
         current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
             "utf-8"
         )
-        if self.text or not self.trim_space:
-            self.text += current_text
-        else:
-            self.text += _remove_space(current_text)
+        self.text += self._maybe_trim_space(current_text)
         self._unflushed = ""
 
     @classmethod
@@ -303,14 +312,7 @@ def _is_spm_decoder_no_space(decoder):
 
 
 def _is_bpe_decoder(decoder):
-    _target_description = {
-        "type": "ByteLevel",
-        "add_prefix_space": False,
-        "trim_offsets": False,
-        "use_regex": False,
-    }
-
-    return _match(_target_description, decoder)
+    return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
 
 
 def load_tokenizer(model_path, tokenizer_config_extra={}):
diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py
new file mode 100644
index 00000000..7b4828b1
--- /dev/null
+++ b/llms/tests/test_tokenizers.py
@@ -0,0 +1,76 @@
+# Copyright © 2024 Apple Inc.
+
+import unittest
+from pathlib import Path
+
+from huggingface_hub import snapshot_download
+from mlx_lm.tokenizer_utils import (
+    BPEStreamingDetokenizer,
+    NaiveStreamingDetokenizer,
+    SPMStreamingDetokenizer,
+    load_tokenizer,
+)
+
+
+class TestTokenizers(unittest.TestCase):
+
+    def download_tokenizer(self, repo):
+        path = Path(
+            snapshot_download(
+                repo_id=repo,
+                allow_patterns=[
+                    "tokenizer.json",
+                    "tokenizer_config.json",
+                    "special_tokens_map.json",
+                    "tokenizer.model",
+                ],
+            )
+        )
+        return load_tokenizer(path)
+
+    def check_tokenizer(self, tokenizer):
+        def check(tokens):
+            expected_text = tokenizer.decode(tokens)
+            detokenizer = tokenizer.detokenizer
+            detokenizer.reset()
+            text = ""
+            for t in tokens:
+                detokenizer.add_token(t)
+                seg = detokenizer.last_segment
+                text += seg
+            detokenizer.finalize()
+            text += detokenizer.last_segment
+            self.assertEqual(text, expected_text)
+
+        tokens = tokenizer.encode("a ,b")
+        check(tokens)
+
+        tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}')
+        check(tokens)
+
+        tokens = tokenizer.encode("3 3")
+        check(tokens)
+
+    def test_tokenizers(self):
+        tokenizer_repos = [
+            ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
+            ("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer),
+            ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
+            ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
+            ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
+        ]
+        for tokenizer_repo, expected_detokenizer in tokenizer_repos:
+            with self.subTest(tokenizer=tokenizer_repo):
+                tokenizer = self.download_tokenizer(tokenizer_repo)
+                tokenizer.decode([0, 1, 2])
+                self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer))
+                self.check_tokenizer(tokenizer)
+
+        # Try one with a naive detokenizer
+        tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
+        tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
+        self.check_tokenizer(tokenizer)
+
+
+if __name__ == "__main__":
+    unittest.main()
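
Usage sketch (illustrative, not part of the patch): how the new can_trim_prompt_cache helper is meant to guard trim_prompt_cache. load and make_prompt_cache are existing mlx_lm entry points; the surrounding driver lines and the trim size are assumptions, and the model repo is one of the test repos above.

    from mlx_lm import load
    from mlx_lm.models.cache import (
        can_trim_prompt_cache,
        make_prompt_cache,
        trim_prompt_cache,
    )

    model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")
    cache = make_prompt_cache(model)  # one cache entry per layer

    # ... generation steps fill `cache` ...

    # Trim only when every layer reports is_trimmable(); otherwise
    # trim_prompt_cache returns 0 and leaves the cache untouched.
    if can_trim_prompt_cache(cache):
        num_trimmed = trim_prompt_cache(cache, num_tokens=16)

The streaming loop the new test exercises follows the same shape outside a test harness, using the detokenizer calls shown in check_tokenizer:

    detokenizer = tokenizer.detokenizer
    detokenizer.reset()
    for token in tokenizer.encode("a ,b"):
        detokenizer.add_token(token)
        # last_segment returns any text finalized since the last access
        print(detokenizer.last_segment, end="")
    detokenizer.finalize()
    print(detokenizer.last_segment)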