# mlx-examples/llms/mlx_lm/tokenizer_utils.py
import json
from functools import partial
from transformers import AutoTokenizer


class StreamingDetokenizer:
    """The streaming detokenizer interface so that we can detokenize one token at a time.

    Example usage is as follows:

        detokenizer = ...

        # Reset the tokenizer state
        detokenizer.reset()

        for token in generate(...):
            detokenizer.add_token(token.item())

            # Contains the whole text so far. Some tokens may not be included
            # since it contains whole words usually.
            detokenizer.text

            # Contains the printable segment (usually a word) since the last
            # time it was accessed
            detokenizer.last_segment

            # Contains all the tokens added so far
            detokenizer.tokens

        # Make sure that we detokenize any remaining tokens
        detokenizer.finalize()

        # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)
    """

    __slots__ = ("text", "tokens", "offset")

    def reset(self):
        raise NotImplementedError()

    def add_token(self, token):
        raise NotImplementedError()

    def finalize(self):
        raise NotImplementedError()

    @property
    def last_segment(self):
        """Return the last segment of readable text since last time this property was accessed."""
        text = self.text
        segment = text[self.offset :]
        self.offset = len(text)
        return segment
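

# A small illustrative sketch of the incremental-print loop that ``last_segment``
# enables. ``detokenizer`` stands for any concrete subclass defined below and
# ``token_ids`` is a placeholder for a stream of generated ids:
#
#     detokenizer.reset()
#     for token in token_ids:
#         detokenizer.add_token(token)
#         print(detokenizer.last_segment, end="", flush=True)
#     detokenizer.finalize()
#     print(detokenizer.last_segment, flush=True)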


class NaiveStreamingDetokenizer(StreamingDetokenizer):
    """NaiveStreamingDetokenizer relies on the underlying tokenizer
    implementation and should work with every tokenizer.

    Its complexity is O(T^2) where T is the longest line since it will
    repeatedly detokenize the same tokens until a new line is generated.
    """

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        self._tokenizer.decode([0])
        self.reset()

    def reset(self):
        self.offset = 0
        self.tokens = []
        self._text = ""
        self._current_tokens = []
        self._current_text = ""

    def add_token(self, token):
        self._current_tokens.append(token)
        self.tokens.append(token)

    def finalize(self):
        self._text += self._tokenizer.decode(self._current_tokens)
        self._current_tokens = []
        self._current_text = ""

    @property
    def text(self):
        if self._current_tokens:
            self._current_text = self._tokenizer.decode(self._current_tokens)
            if (
                self._tokenizer.clean_up_tokenization_spaces
                and self._current_text[-1] == " "
            ):
                self._current_text = self._current_text[:-1]
            if self._current_text and self._current_text[-1] == "\n":
                self._text += self._current_text
                self._current_tokens.clear()
                self._current_text = ""
        return self._text + self._current_text
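

# Rough usage sketch (``model_id`` is a placeholder; any HF tokenizer should work
# with the naive detokenizer):
#
#     tokenizer = AutoTokenizer.from_pretrained(model_id)
#     detok = NaiveStreamingDetokenizer(tokenizer)
#     for t in tokenizer.encode("Hello world"):
#         detok.add_token(t)
#     detok.finalize()
#     # After finalize(), detok.text should match a one-shot decode
#     assert detok.text == tokenizer.decode(detok.tokens)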


class SPMStreamingDetokenizer(StreamingDetokenizer):
    """A streaming detokenizer for SPM models.

    It adds tokens to the text if the next token starts with the special SPM
    underscore which results in linear complexity.
    """

    def __init__(self, tokenizer, trim_space=True):
        self.trim_space = trim_space
        self._sep = "\u2581".encode()

        # Extract the tokens in a list from id to text
        self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
        for value, tokenid in tokenizer.vocab.items():
            if value.startswith("<0x"):
                # Replace bytes with their value
                self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
            else:
                self.tokenmap[tokenid] = value.encode()

        self.reset()

    def reset(self):
        self.offset = 0
        self._unflushed = b""
        self.text = ""
        self.tokens = []

    def _try_flush(self, force=False):
        text = self._unflushed.replace(self._sep, b" ").decode("utf-8", "replace")
        if not force and text.endswith("\ufffd"):
            return
        if not self.text and self.trim_space and text and text[0] == " ":
            text = text[1:]
        self.text += text
        self._unflushed = b""

    def add_token(self, token):
        self.tokens.append(token)
        v = self.tokenmap[token]
        self._unflushed += v
        self._try_flush()

    def finalize(self):
        self._try_flush(force=True)
        self._unflushed = b""
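

# Informally: a vocab piece like "▁world" carries the U+2581 prefix that becomes
# a leading space, while byte-fallback pieces such as "<0x41>" stand for the raw
# byte 0x41. A rough sketch, assuming a SentencePiece-based HF ``tokenizer``:
#
#     detok = SPMStreamingDetokenizer(tokenizer)
#     for t in tokenizer.encode("Hello world", add_special_tokens=False):
#         detok.add_token(t)
#     detok.finalize()
#     print(detok.text)  # roughly "Hello world" (leading space trimmed by default)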


class BPEStreamingDetokenizer(StreamingDetokenizer):
    """A streaming detokenizer for OpenAI style BPE models.

    It adds tokens to the text if the next token starts with a space similar to
    the SPM detokenizer.
    """

    _byte_decoder = None
    _space_matches = (".", "?", "!", ",", "n't", "'m", "'s", "'ve", "'re")

    def __init__(self, tokenizer):
        self.clean_spaces = tokenizer.clean_up_tokenization_spaces

        # Extract the tokens in a list from id to text
        self.tokenmap = [None] * len(tokenizer.vocab)
        for value, tokenid in tokenizer.vocab.items():
            self.tokenmap[tokenid] = value

        self.reset()

        # Make the BPE byte decoder from
        # https://github.com/openai/gpt-2/blob/master/src/encoder.py
        self.make_byte_decoder()

    def reset(self):
        self.offset = 0
        self._unflushed = ""
        self.text = ""
        self.tokens = []

    def _decode_bytes(self, seq):
        barr = bytearray()
        for c in seq:
            res = self._byte_decoder.get(c, False)
            if res:
                barr.append(res)
            else:
                barr.extend(bytes(c, "utf-8"))
        return barr.decode("utf-8", "replace")

    def _maybe_trim_space(self, current_text):
        if len(current_text) == 0:
            return current_text
        elif current_text[0] != " ":
            return current_text
        elif not self.text:
            return current_text[1:]
        elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
            return current_text[1:]
        return current_text

    def add_token(self, token):
        self.tokens.append(token)
        v = self.tokenmap[token]
        self._unflushed += v
        text = self._decode_bytes(self._unflushed)

        # For multi-byte utf-8 wait until they are complete
        # For single spaces wait until the next token to clean it if needed
        if not text.endswith("\ufffd") and not (
            len(v) == 1 and self._byte_decoder[v[0]] == 32
        ):
            self.text += self._maybe_trim_space(text)
            self._unflushed = ""

    def finalize(self):
        current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
            "utf-8",
            "replace",
        )
        self.text += self._maybe_trim_space(current_text)
        self._unflushed = ""

    @classmethod
    def make_byte_decoder(cls):
        """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
        if cls._byte_decoder is not None:
            return

        char_to_bytes = {}
        limits = [
            0,
            ord("!"),
            ord("~") + 1,
            ord("¡"),
            ord("¬") + 1,
            ord("®"),
            ord("ÿ") + 1,
        ]
        n = 0
        for i, (start, stop) in enumerate(zip(limits, limits[1:])):
            if i % 2 == 0:
                for b in range(start, stop):
                    char_to_bytes[chr(2**8 + n)] = b
                    n += 1
            else:
                for b in range(start, stop):
                    char_to_bytes[chr(b)] = b
        cls._byte_decoder = char_to_bytes
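

# The GPT-2 byte encoder shifts non-printable bytes into a printable Unicode
# range, so the decoder built above maps, e.g., "Ġ" (U+0120) back to byte 32
# (space) and "Ċ" (U+010A) back to byte 10 (newline). A quick illustration of
# that mapping (shown for explanation, not part of the module's API):
#
#     BPEStreamingDetokenizer.make_byte_decoder()
#     assert BPEStreamingDetokenizer._byte_decoder["\u0120"] == 32
#     assert BPEStreamingDetokenizer._byte_decoder["\u010a"] == 10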


class TokenizerWrapper:
    """A wrapper that combines an HF tokenizer and a detokenizer.

    Accessing any attribute other than the ``detokenizer`` is forwarded to the
    huggingface tokenizer.
    """

    def __init__(
        self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None
    ):
        self._tokenizer = tokenizer
        self._detokenizer = detokenizer_class(tokenizer)
        self._eos_token_ids = (
            set(eos_token_ids)
            if eos_token_ids is not None
            else {tokenizer.eos_token_id}
        )

    def add_eos_token(self, token: str):
        token_id = None
        try:
            token_id = int(token)
        except ValueError:
            token_id = self._tokenizer.convert_tokens_to_ids(token)

        if token_id is None:
            raise ValueError(f"'{token}' is not a token for this tokenizer")

        self._eos_token_ids.add(token_id)

    def __getattr__(self, attr):
        if attr == "detokenizer":
            return self._detokenizer
        elif attr == "eos_token_ids":
            return self._eos_token_ids
        elif attr.startswith("_"):
            return self.__getattribute__(attr)
        else:
            return getattr(self._tokenizer, attr)

    def __setattr__(self, attr, value):
        if attr in {"detokenizer", "eos_token_ids"}:
            if attr == "detokenizer":
                raise AttributeError("Cannot set the detokenizer.")
            elif attr == "eos_token_ids":
                self._eos_token_ids = set(value) if value is not None else set()
        elif attr.startswith("_"):
            super().__setattr__(attr, value)
        else:
            setattr(self._tokenizer, attr, value)
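

# Wrapper usage sketch (``tokenizer`` is a hypothetical HF tokenizer; the
# "</s>" token assumes an SPM-style vocabulary):
#
#     wrapped = TokenizerWrapper(tokenizer, detokenizer_class=SPMStreamingDetokenizer)
#     wrapped.add_eos_token("</s>")   # by token string ...
#     wrapped.add_eos_token("2")      # ... or by numeric id
#     wrapped.encode("hi")            # forwarded to the HF tokenizer
#     detok = wrapped.detokenizer     # served by __getattr__ above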


def _match(a, b):
    if type(a) != type(b):
        return False
    if isinstance(a, dict):
        return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a)
    if isinstance(a, list):
        return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b))

    return a == b


def _is_spm_decoder(decoder):
    _target_description = {
        "type": "Sequence",
        "decoders": [
            {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
            {"type": "ByteFallback"},
            {"type": "Fuse"},
            {"type": "Strip", "content": " ", "start": 1, "stop": 0},
        ],
    }
    return _match(_target_description, decoder)


def _is_spm_decoder_no_space(decoder):
    _target_description = {
        "type": "Sequence",
        "decoders": [
            {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
            {"type": "ByteFallback"},
            {"type": "Fuse"},
        ],
    }
    return _match(_target_description, decoder)


def _is_bpe_decoder(decoder):
    return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"


def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
    """Load a huggingface tokenizer and try to infer the type of streaming
    detokenizer to use.

    Note, to use a fast streaming tokenizer, pass a local file path rather than
    a Hugging Face repo ID.
    """
    detokenizer_class = NaiveStreamingDetokenizer

    tokenizer_file = model_path / "tokenizer.json"
    if tokenizer_file.exists():
        with open(tokenizer_file, "r") as fid:
            tokenizer_content = json.load(fid)
        if "decoder" in tokenizer_content:
            if _is_spm_decoder(tokenizer_content["decoder"]):
                detokenizer_class = SPMStreamingDetokenizer
            elif _is_spm_decoder_no_space(tokenizer_content["decoder"]):
                detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False)
            elif _is_bpe_decoder(tokenizer_content["decoder"]):
                detokenizer_class = BPEStreamingDetokenizer

    if isinstance(eos_token_ids, int):
        eos_token_ids = [eos_token_ids]

    return TokenizerWrapper(
        AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
        detokenizer_class,
        eos_token_ids=eos_token_ids,
    )
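

# Loading sketch: given a local model directory containing tokenizer.json, the
# decoder section decides which streaming detokenizer is used (the path and
# eos id below are placeholders):
#
#     from pathlib import Path
#
#     tokenizer = load_tokenizer(Path("./my-model"), eos_token_ids=[2])
#     print(type(tokenizer.detokenizer))  # SPM, BPE, or Naive depending on the decoder
#     print(tokenizer.eos_token_ids)      # {2}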