From 1278994b56f769497892614d98248e4d71fab1c0 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 8 Apr 2024 22:36:01 -0700
Subject: [PATCH] Add streaming detokenizers (#651)

---
 llms/mlx_lm/tokenizer_utils.py | 311 +++++++++++++++++++++++++++++++++
 llms/mlx_lm/utils.py           |  46 ++---
 2 files changed, 330 insertions(+), 27 deletions(-)
 create mode 100644 llms/mlx_lm/tokenizer_utils.py

diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
new file mode 100644
index 00000000..50b87773
--- /dev/null
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -0,0 +1,311 @@
+import json
+
+from transformers import AutoTokenizer
+
+REPLACEMENT_CHAR = "\ufffd"
+
+
+def _remove_space(x):
+    if x and x[0] == " ":
+        return x[1:]
+    return x
+
+
+class StreamingDetokenizer:
+    """The streaming detokenizer interface so that we can detokenize one token at a time.
+
+    Example usage is as follows:
+
+        detokenizer = ...
+
+        # Reset the tokenizer state
+        detokenizer.reset()
+
+        for token in generate(...):
+            detokenizer.add_token(token.item())
+
+            # Contains the whole text so far. Some tokens may not be included
+            # since it contains whole words usually.
+            detokenizer.text
+
+            # Contains the printable segment (usually a word) since the last
+            # time it was accessed
+            detokenizer.last_segment
+
+            # Contains all the tokens added so far
+            detokenizer.tokens
+
+        # Make sure that we detokenize any remaining tokens
+        detokenizer.finalize()
+
+        # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)
+    """
+
+    __slots__ = ("text", "tokens", "offset")
+
+    def reset(self):
+        raise NotImplementedError()
+
+    def add_token(self, token):
+        raise NotImplementedError()
+
+    def finalize(self):
+        raise NotImplementedError()
+
+    @property
+    def last_segment(self):
+        """Return the last segment of readable text since last time this property was accessed."""
+        text = self.text
+        if text and text[-1] != REPLACEMENT_CHAR:
+            segment = text[self.offset :]
+            self.offset = len(text)
+            return segment
+        return ""
+
+
+class NaiveStreamingDetokenizer(StreamingDetokenizer):
+    """NaiveStreamingDetokenizer relies on the underlying tokenizer
+    implementation and should work with every tokenizer.
+
+    Its complexity is O(T^2) where T is the longest line since it will
+    repeatedly detokenize the same tokens until a new line is generated.
+    """
+
+    def __init__(self, tokenizer):
+        self._tokenizer = tokenizer
+        self.reset()
+
+    def reset(self):
+        self.offset = 0
+        self._tokens = []
+        self._text = ""
+        self._current_tokens = []
+        self._current_text = ""
+
+    def add_token(self, token):
+        self._current_tokens.append(token)
+
+    def finalize(self):
+        self._tokens.extend(self._current_tokens)
+        self._text += self._tokenizer.decode(self._current_tokens)
+        self._current_tokens = []
+        self._current_text = ""
+
+    @property
+    def text(self):
+        if self._current_tokens:
+            self._current_text = self._tokenizer.decode(self._current_tokens)
+        if self._current_text and self._current_text[-1] == "\n":
+            self._tokens.extend(self._current_tokens)
+            self._text += self._current_text
+            self._current_tokens.clear()
+            self._current_text = ""
+        return self._text + self._current_text
+
+    @property
+    def tokens(self):
+        return self._tokens
+
+
+class SPMStreamingDetokenizer(StreamingDetokenizer):
+    """A streaming detokenizer for SPM models.
+
+    It adds tokens to the text if the next token starts with the special SPM
+    underscore which results in linear complexity.
+    """
+
+    def __init__(self, tokenizer):
+        # Extract the tokens in a list from id to text
+        self.tokenmap = [None] * len(tokenizer.vocab)
+        for value, tokenid in tokenizer.vocab.items():
+            self.tokenmap[tokenid] = value
+
+        # Replace bytes with their value
+        for i in range(len(self.tokenmap)):
+            if self.tokenmap[i].startswith("<0x"):
+                self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
+
+        self.reset()
+
+    def reset(self):
+        self.offset = 0
+        self._unflushed = ""
+        self.text = ""
+        self.tokens = []
+
+    def add_token(self, token):
+        v = self.tokenmap[token]
+        if v[0] == "\u2581":
+            if self.text:
+                self.text += self._unflushed.replace("\u2581", " ")
+            else:
+                self.text = _remove_space(self._unflushed.replace("\u2581", " "))
+            self._unflushed = v
+        else:
+            self._unflushed += v
+
+    def finalize(self):
+        if self.text:
+            self.text += self._unflushed.replace("\u2581", " ")
+        else:
+            self.text = _remove_space(self._unflushed.replace("\u2581", " "))
+        self._unflushed = ""
+
+
+class BPEStreamingDetokenizer(StreamingDetokenizer):
+    """A streaming detokenizer for OpenAI style BPE models.
+
+    It adds tokens to the text if the next token starts with a space similar to
+    the SPM detokenizer.
+    """
+
+    _byte_decoder = None
+
+    def __init__(self, tokenizer, trim_space=False):
+        self.trim_space = trim_space
+
+        # Extract the tokens in a list from id to text
+        self.tokenmap = [None] * len(tokenizer.vocab)
+        for value, tokenid in tokenizer.vocab.items():
+            self.tokenmap[tokenid] = value
+
+        self.reset()
+
+        # Make the BPE byte decoder from
+        # https://github.com/openai/gpt-2/blob/master/src/encoder.py
+        self.make_byte_decoder()
+
+    def reset(self):
+        self.offset = 0
+        self._unflushed = ""
+        self.text = ""
+        self.tokens = []
+
+    def add_token(self, token):
+        v = self.tokenmap[token]
+        # if the token starts with space
+        if self._byte_decoder[v[0]] == 32:
+            current_text = bytearray(
+                self._byte_decoder[c] for c in self._unflushed
+            ).decode("utf-8")
+            if self.text or not self.trim_space:
+                self.text += current_text
+            else:
+                self.text += _remove_space(current_text)
+            self._unflushed = v
+        else:
+            self._unflushed += v
+
+    def finalize(self):
+        current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
+            "utf-8"
+        )
+        if self.text or not self.trim_space:
+            self.text += current_text
+        else:
+            self.text += _remove_space(current_text)
+        self._unflushed = ""
+
+    @classmethod
+    def make_byte_decoder(cls):
+        """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
+        if cls._byte_decoder is not None:
+            return
+
+        char_to_bytes = {}
+        limits = [
+            0,
+            ord("!"),
+            ord("~") + 1,
+            ord("¡"),
+            ord("¬") + 1,
+            ord("®"),
+            ord("ÿ") + 1,
+        ]
+        n = 0
+        for i, (start, stop) in enumerate(zip(limits, limits[1:])):
+            if i % 2 == 0:
+                for b in range(start, stop):
+                    char_to_bytes[chr(2**8 + n)] = b
+                    n += 1
+            else:
+                for b in range(start, stop):
+                    char_to_bytes[chr(b)] = b
+        cls._byte_decoder = char_to_bytes
+
+
+class TokenizerWrapper:
+    """A wrapper that combines an HF tokenizer and a detokenizer.
+
+    Accessing any attribute other than the ``detokenizer`` is forwarded to the
+    huggingface tokenizer.
+    """
+
+    def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
+        self._tokenizer = tokenizer
+        self._detokenizer = detokenizer_class(tokenizer)
+
+    def __getattr__(self, attr):
+        if attr == "detokenizer":
+            return self._detokenizer
+        else:
+            return getattr(self._tokenizer, attr)
+
+
+def _match(a, b):
+    if type(a) != type(b):
+        return False
+    if isinstance(a, dict):
+        return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a)
+    if isinstance(a, list):
+        return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b))
+
+    return a == b
+
+
+def _is_spm_decoder(decoder):
+    _target_description = {
+        "type": "Sequence",
+        "decoders": [
+            {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
+            {"type": "ByteFallback"},
+            {"type": "Fuse"},
+            {"type": "Strip", "content": " ", "start": 1, "stop": 0},
+        ],
+    }
+    return _match(_target_description, decoder)
+
+
+def _is_bpe_decoder(decoder):
+    _target_description = {
+        "type": "ByteLevel",
+        "add_prefix_space": False,
+        "trim_offsets": False,
+        "use_regex": False,
+    }
+
+    return _match(_target_description, decoder)
+
+
+def load_tokenizer(model_path, tokenizer_config_extra={}):
+    """Load a huggingface tokenizer and try to infer the type of streaming
+    detokenizer to use.
+
+    Note, to use a fast streaming tokenizer, pass a local file path rather than
+    a Hugging Face repo ID.
+    """
+    detokenizer_class = NaiveStreamingDetokenizer
+
+    tokenizer_file = model_path / "tokenizer.json"
+    if tokenizer_file.exists():
+        tokenizer_content = json.load(tokenizer_file.open())
+        if "decoder" in tokenizer_content:
+            if _is_spm_decoder(tokenizer_content["decoder"]):
+                detokenizer_class = SPMStreamingDetokenizer
+            elif _is_bpe_decoder(tokenizer_content["decoder"]):
+                detokenizer_class = BPEStreamingDetokenizer
+
+    return TokenizerWrapper(
+        AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
+        detokenizer_class,
+    )
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index ea0f7cfe..45453111 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -17,9 +17,9 @@ from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
 
-from .sample_utils import top_p_sampling
-
 # Local imports
+from .sample_utils import top_p_sampling
+from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import apply_lora_layers
 from .tuner.utils import dequantize as dequantize_model
 
@@ -189,7 +189,7 @@ def generate_step(
 
 def generate(
     model: nn.Module,
-    tokenizer: PreTrainedTokenizer,
+    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: str,
     temp: float = 0.0,
     max_tokens: int = 100,
@@ -215,18 +215,18 @@ def generate(
         repetition_penalty (float, optional): The penalty factor for repeating tokens.
        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
""" + if not isinstance(tokenizer, TokenizerWrapper): + tokenizer = TokenizerWrapper(tokenizer) if verbose: print("=" * 10) print("Prompt:", prompt) prompt_tokens = mx.array(tokenizer.encode(prompt)) + detokenizer = tokenizer.detokenizer tic = time.perf_counter() - tokens = [] - token_strings = [] - skip = 0 - REPLACEMENT_CHAR = "\ufffd" + detokenizer.reset() for (token, prob), n in zip( generate_step( @@ -245,29 +245,21 @@ def generate( tic = time.perf_counter() if token == tokenizer.eos_token_id: break - tokens.append(token) + detokenizer.add_token(token) if verbose: - s = tokenizer.decode(tokens) - if not s: - continue - elif formatter: - formatter(s[skip:], prob.item()) - skip = len(s) - elif s[-1] != REPLACEMENT_CHAR: - print(s[skip:], end="", flush=True) - skip = len(s) - # Reset token cache at line break - if s[-1] == "\n": - tokens = [] - token_strings.append(s) - skip = 0 + if formatter: + # We have to finalize so that the prob corresponds to the last segment + detokenizer.finalize() + formatter(detokenizer.last_segment, prob.item()) + else: + print(detokenizer.last_segment, end="", flush=True) token_count = n + 1 - token_strings.append(tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")) + detokenizer.finalize() if verbose: - print(token_strings[-1][skip:], flush=True) + print(detokenizer.last_segment, flush=True) gen_time = time.perf_counter() - tic print("=" * 10) if token_count == 0: @@ -278,7 +270,7 @@ def generate( print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") print(f"Generation: {gen_tps:.3f} tokens-per-sec") - return "".join(token_strings) + return detokenizer.text def load_model(model_path: Path, lazy: bool = False) -> nn.Module: @@ -384,8 +376,8 @@ def load( if adapter_path is not None: model = apply_lora_layers(model, adapter_path) model.eval() + tokenizer = load_tokenizer(model_path, tokenizer_config) - tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config) return model, tokenizer @@ -394,7 +386,7 @@ def fetch_from_hub( ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: model = load_model(model_path, lazy) config = AutoConfig.from_pretrained(model_path) - tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = load_tokenizer(model_path) return model, config.to_dict(), tokenizer