Add streaming detokenizers (#651)

This commit is contained in:
Angelos Katharopoulos 2024-04-08 22:36:01 -07:00 committed by GitHub
parent c68aa3c7c3
commit 1278994b56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 330 additions and 27 deletions

View File

@ -0,0 +1,311 @@
import json
from transformers import AutoTokenizer
REPLACEMENT_CHAR = "\ufffd"
def _remove_space(x):
if x and x[0] == " ":
return x[1:]
return x
class StreamingDetokenizer:
"""The streaming detokenizer interface so that we can detokenize one token at a time.
Example usage is as follows:
detokenizer = ...
# Reset the tokenizer state
detokenizer.reset()
for token in generate(...):
detokenizer.add_token(token.item())
# Contains the whole text so far. Some tokens may not be included
# since it contains whole words usually.
detokenizer.text
# Contains the printable segment (usually a word) since the last
# time it was accessed
detokenizer.last_segment
# Contains all the tokens added so far
detokenizer.tokens
# Make sure that we detokenize any remaining tokens
detokenizer.finalize()
# Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)
"""
__slots__ = ("text", "tokens", "offset")
def reset(self):
raise NotImplementedError()
def add_token(self, token):
raise NotImplementedError()
def finalize(self):
raise NotImplementedError()
@property
def last_segment(self):
"""Return the last segment of readable text since last time this property was accessed."""
text = self.text
if text and text[-1] != REPLACEMENT_CHAR:
segment = text[self.offset :]
self.offset = len(text)
return segment
return ""
class NaiveStreamingDetokenizer(StreamingDetokenizer):
"""NaiveStreamingDetokenizer relies on the underlying tokenizer
implementation and should work with every tokenizer.
Its complexity is O(T^2) where T is the longest line since it will
repeatedly detokenize the same tokens until a new line is generated.
"""
def __init__(self, tokenizer):
self._tokenizer = tokenizer
self.reset()
def reset(self):
self.offset = 0
self._tokens = []
self._text = ""
self._current_tokens = []
self._current_text = ""
def add_token(self, token):
self._current_tokens.append(token)
def finalize(self):
self._tokens.extend(self._current_tokens)
self._text += self._tokenizer.decode(self._current_tokens)
self._current_tokens = []
self._current_text = ""
@property
def text(self):
if self._current_tokens:
self._current_text = self._tokenizer.decode(self._current_tokens)
if self._current_text and self._current_text[-1] == "\n":
self._tokens.extend(self._current_tokens)
self._text += self._current_text
self._current_tokens.clear()
self._current_text = ""
return self._text + self._current_text
@property
def tokens(self):
return self._tokens
class SPMStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for SPM models.
It adds tokens to the text if the next token starts with the special SPM
underscore which results in linear complexity.
"""
def __init__(self, tokenizer):
# Extract the tokens in a list from id to text
self.tokenmap = [None] * len(tokenizer.vocab)
for value, tokenid in tokenizer.vocab.items():
self.tokenmap[tokenid] = value
# Replace bytes with their value
for i in range(len(self.tokenmap)):
if self.tokenmap[i].startswith("<0x"):
self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
self.reset()
def reset(self):
self.offset = 0
self._unflushed = ""
self.text = ""
self.tokens = []
def add_token(self, token):
v = self.tokenmap[token]
if v[0] == "\u2581":
if self.text:
self.text += self._unflushed.replace("\u2581", " ")
else:
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
self._unflushed = v
else:
self._unflushed += v
def finalize(self):
if self.text:
self.text += self._unflushed.replace("\u2581", " ")
else:
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
self._unflushed = ""
class BPEStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for OpenAI style BPE models.
It adds tokens to the text if the next token starts with a space similar to
the SPM detokenizer.
"""
_byte_decoder = None
def __init__(self, tokenizer, trim_space=False):
self.trim_space = trim_space
# Extract the tokens in a list from id to text
self.tokenmap = [None] * len(tokenizer.vocab)
for value, tokenid in tokenizer.vocab.items():
self.tokenmap[tokenid] = value
self.reset()
# Make the BPE byte decoder from
# https://github.com/openai/gpt-2/blob/master/src/encoder.py
self.make_byte_decoder()
def reset(self):
self.offset = 0
self._unflushed = ""
self.text = ""
self.tokens = []
def add_token(self, token):
v = self.tokenmap[token]
# if the token starts with space
if self._byte_decoder[v[0]] == 32:
current_text = bytearray(
self._byte_decoder[c] for c in self._unflushed
).decode("utf-8")
if self.text or not self.trim_space:
self.text += current_text
else:
self.text += _remove_space(current_text)
self._unflushed = v
else:
self._unflushed += v
def finalize(self):
current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
"utf-8"
)
if self.text or not self.trim_space:
self.text += current_text
else:
self.text += _remove_space(current_text)
self._unflushed = ""
@classmethod
def make_byte_decoder(cls):
"""See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
if cls._byte_decoder is not None:
return
char_to_bytes = {}
limits = [
0,
ord("!"),
ord("~") + 1,
ord("¡"),
ord("¬") + 1,
ord("®"),
ord("ÿ") + 1,
]
n = 0
for i, (start, stop) in enumerate(zip(limits, limits[1:])):
if i % 2 == 0:
for b in range(start, stop):
char_to_bytes[chr(2**8 + n)] = b
n += 1
else:
for b in range(start, stop):
char_to_bytes[chr(b)] = b
cls._byte_decoder = char_to_bytes
class TokenizerWrapper:
"""A wrapper that combines an HF tokenizer and a detokenizer.
Accessing any attribute other than the ``detokenizer`` is forwarded to the
huggingface tokenizer.
"""
def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
self._tokenizer = tokenizer
self._detokenizer = detokenizer_class(tokenizer)
def __getattr__(self, attr):
if attr == "detokenizer":
return self._detokenizer
else:
return getattr(self._tokenizer, attr)
def _match(a, b):
if type(a) != type(b):
return False
if isinstance(a, dict):
return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a)
if isinstance(a, list):
return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b))
return a == b
def _is_spm_decoder(decoder):
_target_description = {
"type": "Sequence",
"decoders": [
{"type": "Replace", "pattern": {"String": ""}, "content": " "},
{"type": "ByteFallback"},
{"type": "Fuse"},
{"type": "Strip", "content": " ", "start": 1, "stop": 0},
],
}
return _match(_target_description, decoder)
def _is_bpe_decoder(decoder):
_target_description = {
"type": "ByteLevel",
"add_prefix_space": False,
"trim_offsets": False,
"use_regex": False,
}
return _match(_target_description, decoder)
def load_tokenizer(model_path, tokenizer_config_extra={}):
"""Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use.
Note, to use a fast streaming tokenizer, pass a local file path rather than
a Hugging Face repo ID.
"""
detokenizer_class = NaiveStreamingDetokenizer
tokenizer_file = model_path / "tokenizer.json"
if tokenizer_file.exists():
tokenizer_content = json.load(tokenizer_file.open())
if "decoder" in tokenizer_content:
if _is_spm_decoder(tokenizer_content["decoder"]):
detokenizer_class = SPMStreamingDetokenizer
elif _is_bpe_decoder(tokenizer_content["decoder"]):
detokenizer_class = BPEStreamingDetokenizer
return TokenizerWrapper(
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
detokenizer_class,
)

View File

@ -17,9 +17,9 @@ from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
from .sample_utils import top_p_sampling
# Local imports
from .sample_utils import top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import apply_lora_layers
from .tuner.utils import dequantize as dequantize_model
@ -189,7 +189,7 @@ def generate_step(
def generate(
model: nn.Module,
tokenizer: PreTrainedTokenizer,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
temp: float = 0.0,
max_tokens: int = 100,
@ -215,18 +215,18 @@ def generate(
repetition_penalty (float, optional): The penalty factor for repeating tokens.
repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
if verbose:
print("=" * 10)
print("Prompt:", prompt)
prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
tic = time.perf_counter()
tokens = []
token_strings = []
skip = 0
REPLACEMENT_CHAR = "\ufffd"
detokenizer.reset()
for (token, prob), n in zip(
generate_step(
@ -245,29 +245,21 @@ def generate(
tic = time.perf_counter()
if token == tokenizer.eos_token_id:
break
tokens.append(token)
detokenizer.add_token(token)
if verbose:
s = tokenizer.decode(tokens)
if not s:
continue
elif formatter:
formatter(s[skip:], prob.item())
skip = len(s)
elif s[-1] != REPLACEMENT_CHAR:
print(s[skip:], end="", flush=True)
skip = len(s)
# Reset token cache at line break
if s[-1] == "\n":
tokens = []
token_strings.append(s)
skip = 0
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
formatter(detokenizer.last_segment, prob.item())
else:
print(detokenizer.last_segment, end="", flush=True)
token_count = n + 1
token_strings.append(tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, ""))
detokenizer.finalize()
if verbose:
print(token_strings[-1][skip:], flush=True)
print(detokenizer.last_segment, flush=True)
gen_time = time.perf_counter() - tic
print("=" * 10)
if token_count == 0:
@ -278,7 +270,7 @@ def generate(
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
return "".join(token_strings)
return detokenizer.text
def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
@ -384,8 +376,8 @@ def load(
if adapter_path is not None:
model = apply_lora_layers(model, adapter_path)
model.eval()
tokenizer = load_tokenizer(model_path, tokenizer_config)
tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
return model, tokenizer
@ -394,7 +386,7 @@ def fetch_from_hub(
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
model = load_model(model_path, lazy)
config = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer = load_tokenizer(model_path)
return model, config.to_dict(), tokenizer