Tokenizer updates + tests (#1024)
* tokenizer updates + tests
* nit
* add can_trim_prompt_cache
* nits
Parent: 6c368f2124
Commit: 8dca1a2f60
@@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
     return cache


+def can_trim_prompt_cache(cache: List[Any]) -> bool:
+    """
+    Check if model's cache can be trimmed.
+    """
+    return all(c.is_trimmable() for c in cache)
+
+
 def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
     """
     Trim the model's cache by the given number of tokens.
@@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
     Returns:
         (int): The number of tokens that were trimmed.
     """
-    if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
+    if not can_trim_prompt_cache(cache) or len(cache) == 0:
         return 0
     return [c.trim(num_tokens) for c in cache][0]
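With can_trim_prompt_cache exposed, callers can check trimmability up front instead of letting trim_prompt_cache silently return 0. A minimal usage sketch, assuming the helpers sit next to make_prompt_cache in mlx_lm.models.cache and that the model download succeeds (the repo name is borrowed from the tests below):

import mlx_lm
from mlx_lm.models.cache import (
    can_trim_prompt_cache,
    make_prompt_cache,
    trim_prompt_cache,
)

model, tokenizer = mlx_lm.load("mlx-community/Llama-3.2-1B-Instruct-4bit")
cache = make_prompt_cache(model)
# ... feed prompt tokens through the model with cache=cache ...

# Rotating or otherwise lossy caches may not be trimmable, so guard first.
if can_trim_prompt_cache(cache):
    num_trimmed = trim_prompt_cache(cache, 8)  # drop the last 8 cached tokens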
@@ -220,17 +220,17 @@ class DeepseekV2Attention(nn.Module):
         k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)

-        k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
-
         if cache is not None:
             q_pe = self.rope(q_pe, cache.offset)
             k_pe = self.rope(k_pe, cache.offset)
+            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
             keys, values = cache.update_and_fetch(
                 mx.concatenate([k_nope, k_pe], axis=-1), values
             )
         else:
             q_pe = self.rope(q_pe)
             k_pe = self.rope(k_pe)
+            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
             keys = mx.concatenate([k_nope, k_pe], axis=-1)

         queries = mx.concatenate([q_nope, q_pe], axis=-1)
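The k_pe tile now happens per branch, after RoPE, so the rotation runs once on the single kv-head rather than on num_heads identical copies. For a size-1 axis, mx.repeat produces exactly what the old list-concatenate did; a small sketch (shapes are illustrative, not the model's):

import mlx.core as mx

k_pe = mx.random.normal(shape=(1, 1, 5, 8))  # (batch, 1 kv-head, seq, rope dims)
num_heads = 4

tiled = mx.concatenate([k_pe] * num_heads, axis=1)
repeated = mx.repeat(k_pe, num_heads, axis=1)
assert mx.array_equal(tiled, repeated)  # identical when the repeated axis has size 1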
@@ -291,7 +291,7 @@ class MoEGate(nn.Module):
         scores = scores.reshape(bsz, seq_len, -1)

         k = self.top_k
-        inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
+        inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
         scores = mx.take_along_axis(scores, inds, axis=-1)
         scores = scores * self.routed_scaling_factor
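Removing mx.stop_gradient is safe because argpartition returns integer indices, which carry no gradient to stop. The selection itself is a standard partial top-k; in isolation (shapes illustrative):

import mlx.core as mx

scores = mx.random.uniform(shape=(2, 3, 8))  # (batch, seq, num_experts)
k = 2

# argpartition of -scores places the k largest entries first, so [..., :k]
# are the top-k expert indices (unordered within the top k).
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
topk = mx.take_along_axis(scores, inds, axis=-1)
print(topk.shape)  # (2, 3, 2)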
@@ -97,6 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
     def text(self):
         if self._current_tokens:
             self._current_text = self._tokenizer.decode(self._current_tokens)
+            if (
+                self._tokenizer.clean_up_tokenization_spaces
+                and self._current_text[-1] == " "
+            ):
+                self._current_text = self._current_text[:-1]
         if self._current_text and self._current_text[-1] == "\n":
             self._tokens.extend(self._current_tokens)
             self._text += self._current_text
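The added guard handles tokenizers whose decode() appends a cleanup space mid-stream: the dangling space is held back until more tokens arrive. The logic, as a hypothetical standalone function:

def strip_dangling_space(decoded, clean_up_tokenization_spaces):
    # Hold back a trailing cleanup space until the next tokens arrive.
    if clean_up_tokenization_spaces and decoded[-1] == " ":
        return decoded[:-1]
    return decoded

assert strip_dangling_space("Hello ", True) == "Hello"
assert strip_dangling_space("Hello ", False) == "Hello "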
@@ -164,9 +169,11 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
     """

     _byte_decoder = None
+    _space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re")

-    def __init__(self, tokenizer, trim_space=False):
-        self.trim_space = trim_space
+    def __init__(self, tokenizer):
+        self.clean_spaces = tokenizer.clean_up_tokenization_spaces

         # Extract the tokens in a list from id to text
         self.tokenmap = [None] * len(tokenizer.vocab)
@@ -185,17 +192,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
         self.text = ""
         self.tokens = []

+    def _maybe_trim_space(self, current_text):
+        if current_text[0] != " ":
+            return current_text
+        elif not self.text:
+            return current_text[1:]
+        elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
+            return current_text[1:]
+        return current_text
+
     def add_token(self, token):
         v = self.tokenmap[token]
-        # if the token starts with space
         if self._byte_decoder[v[0]] == 32:
             current_text = bytearray(
                 self._byte_decoder[c] for c in self._unflushed
             ).decode("utf-8")
-            if self.text or not self.trim_space:
-                self.text += current_text
-            else:
-                self.text += _remove_space(current_text)
+            self.text += self._maybe_trim_space(current_text)
             self._unflushed = v
         else:
             self._unflushed += v
@@ -204,10 +216,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
         current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
             "utf-8"
         )
-        if self.text or not self.trim_space:
-            self.text += current_text
-        else:
-            self.text += _remove_space(current_text)
+        self.text += self._maybe_trim_space(current_text)
         self._unflushed = ""

     @classmethod
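The three-way policy in _maybe_trim_space replaces the old all-or-nothing trim_space flag: never start the stream with a space, and under clean_up_tokenization_spaces drop the space that byte-level BPE decoding puts before punctuation and contractions. A standalone replica for checking the cases (text_so_far and clean_spaces stand in for instance state):

_SPACE_MATCHES = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re")

def maybe_trim_space(text_so_far, current_text, clean_spaces):
    if current_text[0] != " ":
        return current_text
    elif not text_so_far:
        return current_text[1:]  # never start the stream with a space
    elif clean_spaces and current_text[1:].startswith(_SPACE_MATCHES):
        return current_text[1:]  # " ," -> ",", " 's" -> "'s", etc.
    return current_text

assert maybe_trim_space("", " Hello", True) == "Hello"
assert maybe_trim_space("a", " ,b", True) == ",b"
assert maybe_trim_space("Hello", " world", True) == " world"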
@@ -303,14 +312,7 @@ def _is_spm_decoder_no_space(decoder):


 def _is_bpe_decoder(decoder):
-    _target_description = {
-        "type": "ByteLevel",
-        "add_prefix_space": False,
-        "trim_offsets": False,
-        "use_regex": False,
-    }
-
-    return _match(_target_description, decoder)
+    return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"


 def load_tokenizer(model_path, tokenizer_config_extra={}):
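The predicate is deliberately looser: any ByteLevel decoder now matches, where _match previously required add_prefix_space, trim_offsets, and use_regex to all be False. A quick sanity sketch of the new check:

def _is_bpe_decoder(decoder):
    return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"

assert _is_bpe_decoder({"type": "ByteLevel", "add_prefix_space": True})
assert _is_bpe_decoder({"type": "ByteLevel"})
assert not _is_bpe_decoder({"type": "Sequence"})
assert not _is_bpe_decoder(None)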
llms/tests/test_tokenizers.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+# Copyright © 2024 Apple Inc.
+
+import unittest
+from pathlib import Path
+
+from huggingface_hub import snapshot_download
+from mlx_lm.tokenizer_utils import (
+    BPEStreamingDetokenizer,
+    NaiveStreamingDetokenizer,
+    SPMStreamingDetokenizer,
+    load_tokenizer,
+)
+
+
+class TestTokenizers(unittest.TestCase):
+
+    def download_tokenizer(self, repo):
+        path = Path(
+            snapshot_download(
+                repo_id=repo,
+                allow_patterns=[
+                    "tokenizer.json",
+                    "tokenizer_config.json",
+                    "special_tokens_map.json",
+                    "tokenizer.model",
+                ],
+            )
+        )
+        return load_tokenizer(path)
+
+    def check_tokenizer(self, tokenizer):
+        def check(tokens):
+            expected_text = tokenizer.decode(tokens)
+            detokenizer = tokenizer.detokenizer
+            detokenizer.reset()
+            text = ""
+            for t in tokens:
+                detokenizer.add_token(t)
+                seg = detokenizer.last_segment
+                text += seg
+            detokenizer.finalize()
+            text += detokenizer.last_segment
+            self.assertEqual(text, expected_text)
+
+        tokens = tokenizer.encode("a ,b")
+        check(tokens)
+
+        tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}')
+        check(tokens)
+
+        tokens = tokenizer.encode("3 3")
+        check(tokens)
+
+    def test_tokenizers(self):
+        tokenizer_repos = [
+            ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
+            ("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer),
+            ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
+            ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
+            ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
+        ]
+        for tokenizer_repo, expected_detokenizer in tokenizer_repos:
+            with self.subTest(tokenizer=tokenizer_repo):
+                tokenizer = self.download_tokenizer(tokenizer_repo)
+                tokenizer.decode([0, 1, 2])
+                self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer))
+                self.check_tokenizer(tokenizer)
+
+        # Try one with a naive detokenizer
+        tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
+        tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
+        self.check_tokenizer(tokenizer)
+
+
+if __name__ == "__main__":
+    unittest.main()
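Note that the tests fetch each tokenizer from the Hugging Face Hub, so the first run needs network access. With mlx_lm importable, running the file directly (python llms/tests/test_tokenizers.py from the repository root) should exercise all five repos plus the naive-detokenizer fallback.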