Tokenizer updates + tests (#1024)

* tokenizer updates + tests

* nit

* add can_trim_prompt_cache

* nits
Awni Hannun 2024-10-14 10:48:46 -07:00 committed by GitHub
parent 6c368f2124
commit 8dca1a2f60
4 changed files with 108 additions and 23 deletions

llms/mlx_lm/models/cache.py

@@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
     return cache
 
 
+def can_trim_prompt_cache(cache: List[Any]) -> bool:
+    """
+    Check if model's cache can be trimmed.
+    """
+    return all(c.is_trimmable() for c in cache)
+
+
 def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
     """
     Trim the model's cache by the given number of tokens.
@@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
     Returns:
         (int): The number of tokens that were trimmed.
     """
-    if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
+    if not can_trim_prompt_cache(cache) or len(cache) == 0:
        return 0
     return [c.trim(num_tokens) for c in cache][0]
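
For reference, a minimal sketch of how the new helper is meant to compose with trimming (the model repo is illustrative and the generation step is elided; `load` and `make_prompt_cache` are existing mlx_lm entry points):

from mlx_lm import load
from mlx_lm.models.cache import (
    can_trim_prompt_cache,
    make_prompt_cache,
    trim_prompt_cache,
)

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # illustrative
cache = make_prompt_cache(model)

# ... run generation that fills `cache` with prompt tokens ...

# Guard before trimming: trim_prompt_cache returns the number of tokens
# actually removed, and 0 when any layer's cache is not trimmable.
if can_trim_prompt_cache(cache):
    trimmed = trim_prompt_cache(cache, num_tokens=8)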

llms/mlx_lm/models/deepseek_v2.py

@@ -220,17 +220,17 @@ class DeepseekV2Attention(nn.Module):
         k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
 
-        k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
-
         if cache is not None:
             q_pe = self.rope(q_pe, cache.offset)
             k_pe = self.rope(k_pe, cache.offset)
+            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
             keys, values = cache.update_and_fetch(
                 mx.concatenate([k_nope, k_pe], axis=-1), values
             )
         else:
             q_pe = self.rope(q_pe)
             k_pe = self.rope(k_pe)
+            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
             keys = mx.concatenate([k_nope, k_pe], axis=-1)
 
         queries = mx.concatenate([q_nope, q_pe], axis=-1)
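
The attention change is not just cosmetic: the old code tiled k_pe across all heads and then ran RoPE on every copy, while the new code applies RoPE to the single source head and broadcasts afterwards. Since the copies are identical, mx.repeat on a size-1 axis matches the old mx.concatenate tiling; a quick check of that equivalence (toy shapes, not from the commit):

import mlx.core as mx

k_pe = mx.random.normal((1, 1, 4, 8))       # (batch, heads=1, seq, head_dim)
tiled = mx.concatenate([k_pe] * 3, axis=1)  # old approach: tile along heads
repeated = mx.repeat(k_pe, 3, axis=1)       # new approach: repeat along heads
assert mx.array_equal(tiled, repeated)      # identical when the axis has size 1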
@@ -291,7 +291,7 @@ class MoEGate(nn.Module):
         scores = scores.reshape(bsz, seq_len, -1)
 
         k = self.top_k
-        inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
+        inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
         scores = mx.take_along_axis(scores, inds, axis=-1)
         scores = scores * self.routed_scaling_factor
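
The MoEGate edit drops mx.stop_gradient, which was redundant: argpartition returns integer indices, so no gradient flows through the selection anyway. A toy run of the top-k routing pattern (values are illustrative):

import mlx.core as mx

scores = mx.array([[0.1, 0.7, 0.2, 0.9]])
k = 2
# argpartition on -scores moves the k largest scores into the first k
# slots; their relative order is unspecified, which routing tolerates.
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
topk = mx.take_along_axis(scores, inds, axis=-1)  # contains 0.9 and 0.7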

llms/mlx_lm/tokenizer_utils.py

@@ -97,6 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
     def text(self):
         if self._current_tokens:
             self._current_text = self._tokenizer.decode(self._current_tokens)
+            if (
+                self._tokenizer.clean_up_tokenization_spaces
+                and self._current_text[-1] == " "
+            ):
+                self._current_text = self._current_text[:-1]
         if self._current_text and self._current_text[-1] == "\n":
             self._tokens.extend(self._current_tokens)
             self._text += self._current_text
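
clean_up_tokenization_spaces is the Hugging Face flag that glues punctuation and contractions back onto the preceding word at decode time. The naive detokenizer now strips a trailing space from a partial decode when the flag is set, so the streamed text matches what a full decode would eventually produce. A quick illustration of the flag itself (model choice is illustrative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # any BPE tokenizer works here
ids = tok.encode("a ,b")
print(tok.decode(ids, clean_up_tokenization_spaces=False))  # "a ,b"
print(tok.decode(ids, clean_up_tokenization_spaces=True))   # "a,b"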
@@ -164,9 +169,11 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
     """
 
     _byte_decoder = None
+    _space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re")
 
-    def __init__(self, tokenizer, trim_space=False):
-        self.trim_space = trim_space
+    def __init__(self, tokenizer):
+        self.clean_spaces = tokenizer.clean_up_tokenization_spaces
 
         # Extract the tokens in a list from id to text
         self.tokenmap = [None] * len(tokenizer.vocab)
@@ -185,17 +192,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
         self.text = ""
         self.tokens = []
 
+    def _maybe_trim_space(self, current_text):
+        if current_text[0] != " ":
+            return current_text
+        elif not self.text:
+            return current_text[1:]
+        elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
+            return current_text[1:]
+        return current_text
+
     def add_token(self, token):
         v = self.tokenmap[token]
-        # if the token starts with space
         if self._byte_decoder[v[0]] == 32:
             current_text = bytearray(
                 self._byte_decoder[c] for c in self._unflushed
             ).decode("utf-8")
-            if self.text or not self.trim_space:
-                self.text += current_text
-            else:
-                self.text += _remove_space(current_text)
+            self.text += self._maybe_trim_space(current_text)
             self._unflushed = v
         else:
             self._unflushed += v
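
A standalone restatement of _maybe_trim_space's branches, with the expected results spelled out (this mirrors the method above but is not from the commit; prior_text stands in for self.text):

_space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re")

def maybe_trim_space(prior_text, current_text, clean_spaces=True):
    if current_text[0] != " ":
        return current_text          # no leading space: nothing to do
    elif not prior_text:
        return current_text[1:]      # drop the leading space at stream start
    elif clean_spaces and current_text[1:].startswith(_space_matches):
        return current_text[1:]      # " ,", " 's", ...: glue to previous word
    return current_text

assert maybe_trim_space("", " Hello") == "Hello"
assert maybe_trim_space("Hello", " world") == " world"
assert maybe_trim_space("Hello", " ,") == ","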
@@ -204,10 +216,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
         current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
             "utf-8"
         )
-        if self.text or not self.trim_space:
-            self.text += current_text
-        else:
-            self.text += _remove_space(current_text)
+        self.text += self._maybe_trim_space(current_text)
         self._unflushed = ""
 
     @classmethod
@@ -303,14 +312,7 @@ def _is_spm_decoder_no_space(decoder):
 def _is_bpe_decoder(decoder):
-    _target_description = {
-        "type": "ByteLevel",
-        "add_prefix_space": False,
-        "trim_offsets": False,
-        "use_regex": False,
-    }
-
-    return _match(_target_description, decoder)
+    return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
 
 
 def load_tokenizer(model_path, tokenizer_config_extra={}):
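
The old _is_bpe_decoder required an exact match on add_prefix_space, trim_offsets, and use_regex, so tokenizers with any other ByteLevel configuration fell through to the naive detokenizer. The loosened check keys on the decoder type alone; toy dicts mimicking tokenizer.json decoder entries show the difference:

def _is_bpe_decoder(decoder):
    return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"

# Accepted now, rejected before (add_prefix_space differs from the old template):
assert _is_bpe_decoder({"type": "ByteLevel", "add_prefix_space": True})
# Non-ByteLevel decoders still fall through:
assert not _is_bpe_decoder({"type": "Sequence"})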

llms/tests/test_tokenizers.py

@@ -0,0 +1,76 @@
+# Copyright © 2024 Apple Inc.
+
+import unittest
+from pathlib import Path
+
+from huggingface_hub import snapshot_download
+
+from mlx_lm.tokenizer_utils import (
+    BPEStreamingDetokenizer,
+    NaiveStreamingDetokenizer,
+    SPMStreamingDetokenizer,
+    load_tokenizer,
+)
+
+
+class TestTokenizers(unittest.TestCase):
+
+    def download_tokenizer(self, repo):
+        path = Path(
+            snapshot_download(
+                repo_id=repo,
+                allow_patterns=[
+                    "tokenizer.json",
+                    "tokenizer_config.json",
+                    "special_tokens_map.json",
+                    "tokenizer.model",
+                ],
+            )
+        )
+        return load_tokenizer(path)
+
+    def check_tokenizer(self, tokenizer):
+        def check(tokens):
+            expected_text = tokenizer.decode(tokens)
+            detokenizer = tokenizer.detokenizer
+            detokenizer.reset()
+            text = ""
+            for t in tokens:
+                detokenizer.add_token(t)
+                seg = detokenizer.last_segment
+                text += seg
+            detokenizer.finalize()
+            text += detokenizer.last_segment
+            self.assertEqual(text, expected_text)
+
+        tokens = tokenizer.encode("a ,b")
+        check(tokens)
+
+        tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}')
+        check(tokens)
+
+        tokens = tokenizer.encode("3 3")
+        check(tokens)
+
+    def test_tokenizers(self):
+        tokenizer_repos = [
+            ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
+            ("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer),
+            ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
+            ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
+            ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
+        ]
+        for tokenizer_repo, expected_detokenizer in tokenizer_repos:
+            with self.subTest(tokenizer=tokenizer_repo):
+                tokenizer = self.download_tokenizer(tokenizer_repo)
+                tokenizer.decode([0, 1, 2])
+                self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer))
+                self.check_tokenizer(tokenizer)
+
+        # Try one with a naive detokenizer
+        tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
+        tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
+        self.check_tokenizer(tokenizer)
+
+
+if __name__ == "__main__":
+    unittest.main()
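
To run the new tests locally, invoke unittest from the directory that contains tests/ (path assumed from the repo layout); the cases download small tokenizer files from the Hub, so network access is required on first run:

python -m unittest tests.test_tokenizers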