mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Tokenizer updates + tests (#1024)
* tokenizer updates + tests * nit * add can_trim_prompt_cache * nits
This commit is contained in:
parent
6c368f2124
commit
8dca1a2f60
@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
|
||||
return cache
|
||||
|
||||
|
||||
def can_trim_prompt_cache(cache: List[Any]) -> bool:
|
||||
"""
|
||||
Check if model's cache can be trimmed.
|
||||
"""
|
||||
return all(c.is_trimmable() for c in cache)
|
||||
|
||||
|
||||
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
||||
"""
|
||||
Trim the model's cache by the given number of tokens.
|
||||
@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
||||
Returns:
|
||||
(int): The number of tokens that were trimmed.
|
||||
"""
|
||||
if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
|
||||
if not can_trim_prompt_cache(cache) or len(cache) == 0:
|
||||
return 0
|
||||
return [c.trim(num_tokens) for c in cache][0]
|
||||
|
||||
|
@ -220,17 +220,17 @@ class DeepseekV2Attention(nn.Module):
|
||||
|
||||
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
|
||||
|
||||
k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
q_pe = self.rope(q_pe, cache.offset)
|
||||
k_pe = self.rope(k_pe, cache.offset)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys, values = cache.update_and_fetch(
|
||||
mx.concatenate([k_nope, k_pe], axis=-1), values
|
||||
)
|
||||
else:
|
||||
q_pe = self.rope(q_pe)
|
||||
k_pe = self.rope(k_pe)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys = mx.concatenate([k_nope, k_pe], axis=-1)
|
||||
|
||||
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
||||
@ -291,7 +291,7 @@ class MoEGate(nn.Module):
|
||||
scores = scores.reshape(bsz, seq_len, -1)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
|
||||
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
scores = scores * self.routed_scaling_factor
|
||||
|
||||
|
@ -97,6 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
|
||||
def text(self):
|
||||
if self._current_tokens:
|
||||
self._current_text = self._tokenizer.decode(self._current_tokens)
|
||||
if (
|
||||
self._tokenizer.clean_up_tokenization_spaces
|
||||
and self._current_text[-1] == " "
|
||||
):
|
||||
self._current_text = self._current_text[:-1]
|
||||
if self._current_text and self._current_text[-1] == "\n":
|
||||
self._tokens.extend(self._current_tokens)
|
||||
self._text += self._current_text
|
||||
@ -164,9 +169,11 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
||||
"""
|
||||
|
||||
_byte_decoder = None
|
||||
_space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re")
|
||||
|
||||
def __init__(self, tokenizer, trim_space=False):
|
||||
self.trim_space = trim_space
|
||||
def __init__(self, tokenizer):
|
||||
|
||||
self.clean_spaces = tokenizer.clean_up_tokenization_spaces
|
||||
|
||||
# Extract the tokens in a list from id to text
|
||||
self.tokenmap = [None] * len(tokenizer.vocab)
|
||||
@ -185,17 +192,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
||||
self.text = ""
|
||||
self.tokens = []
|
||||
|
||||
def _maybe_trim_space(self, current_text):
|
||||
if current_text[0] != " ":
|
||||
return current_text
|
||||
elif not self.text:
|
||||
return current_text[1:]
|
||||
elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
|
||||
return current_text[1:]
|
||||
return current_text
|
||||
|
||||
def add_token(self, token):
|
||||
v = self.tokenmap[token]
|
||||
# if the token starts with space
|
||||
if self._byte_decoder[v[0]] == 32:
|
||||
current_text = bytearray(
|
||||
self._byte_decoder[c] for c in self._unflushed
|
||||
).decode("utf-8")
|
||||
if self.text or not self.trim_space:
|
||||
self.text += current_text
|
||||
else:
|
||||
self.text += _remove_space(current_text)
|
||||
self.text += self._maybe_trim_space(current_text)
|
||||
self._unflushed = v
|
||||
else:
|
||||
self._unflushed += v
|
||||
@ -204,10 +216,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
||||
current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
|
||||
"utf-8"
|
||||
)
|
||||
if self.text or not self.trim_space:
|
||||
self.text += current_text
|
||||
else:
|
||||
self.text += _remove_space(current_text)
|
||||
self.text += self._maybe_trim_space(current_text)
|
||||
self._unflushed = ""
|
||||
|
||||
@classmethod
|
||||
@ -303,14 +312,7 @@ def _is_spm_decoder_no_space(decoder):
|
||||
|
||||
|
||||
def _is_bpe_decoder(decoder):
|
||||
_target_description = {
|
||||
"type": "ByteLevel",
|
||||
"add_prefix_space": False,
|
||||
"trim_offsets": False,
|
||||
"use_regex": False,
|
||||
}
|
||||
|
||||
return _match(_target_description, decoder)
|
||||
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
|
||||
|
||||
|
||||
def load_tokenizer(model_path, tokenizer_config_extra={}):
|
||||
|
76
llms/tests/test_tokenizers.py
Normal file
76
llms/tests/test_tokenizers.py
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx_lm.tokenizer_utils import (
|
||||
BPEStreamingDetokenizer,
|
||||
NaiveStreamingDetokenizer,
|
||||
SPMStreamingDetokenizer,
|
||||
load_tokenizer,
|
||||
)
|
||||
|
||||
|
||||
class TestTokenizers(unittest.TestCase):
|
||||
|
||||
def download_tokenizer(self, repo):
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo,
|
||||
allow_patterns=[
|
||||
"tokenizer.json",
|
||||
"tokenizer_config.json",
|
||||
"special_tokens_map.json",
|
||||
"tokenizer.model",
|
||||
],
|
||||
)
|
||||
)
|
||||
return load_tokenizer(path)
|
||||
|
||||
def check_tokenizer(self, tokenizer):
|
||||
def check(tokens):
|
||||
expected_text = tokenizer.decode(tokens)
|
||||
detokenizer = tokenizer.detokenizer
|
||||
detokenizer.reset()
|
||||
text = ""
|
||||
for t in tokens:
|
||||
detokenizer.add_token(t)
|
||||
seg = detokenizer.last_segment
|
||||
text += seg
|
||||
detokenizer.finalize()
|
||||
text += detokenizer.last_segment
|
||||
self.assertEqual(text, expected_text)
|
||||
|
||||
tokens = tokenizer.encode("a ,b")
|
||||
check(tokens)
|
||||
|
||||
tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}')
|
||||
check(tokens)
|
||||
|
||||
tokens = tokenizer.encode("3 3")
|
||||
check(tokens)
|
||||
|
||||
def test_tokenizers(self):
|
||||
tokenizer_repos = [
|
||||
("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
|
||||
("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer),
|
||||
("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
|
||||
("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
|
||||
("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
|
||||
]
|
||||
for tokenizer_repo, expected_detokenizer in tokenizer_repos:
|
||||
with self.subTest(tokenizer=tokenizer_repo):
|
||||
tokenizer = self.download_tokenizer(tokenizer_repo)
|
||||
tokenizer.decode([0, 1, 2])
|
||||
self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer))
|
||||
self.check_tokenizer(tokenizer)
|
||||
|
||||
# Try one with a naive detokenizer
|
||||
tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
|
||||
tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
|
||||
self.check_tokenizer(tokenizer)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user