From e55a9e8cb48cd3e227e85bc306a0793c404032c9 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 15 Apr 2024 14:15:25 -0700
Subject: [PATCH] Add an SPM detokenizer that doesn't trim initial space (#681)

---
Notes:
    Line structure restored from a whitespace-collapsed copy of this patch.
    Indentation and blank context lines were re-derived from Python block
    structure and checked against the hunk header counts and the diffstat
    (20 insertions, 3 deletions). The final blank context line of the last
    hunk is restored to satisfy its "-302,6" count -- confirm against the
    upstream commit. The author's e-mail address and the trailing signature
    were not present in that copy and are left out rather than guessed.

    What the change does:
    * SPMStreamingDetokenizer.__init__ takes a new `trim_space` argument
      (default True) and stores it on the instance.
    * add_token() and finalize() take the plain "append" branch when
      `trim_space` is False, so `_remove_space` is only applied to the
      first segment when `trim_space` is True and no text has been
      emitted yet.
    * _is_spm_decoder_no_space() matches a decoder description of
      Sequence[Replace "▁" -> " ", ByteFallback, Fuse].
    * load_tokenizer() selects
      partial(SPMStreamingDetokenizer, trim_space=False) for such decoders.

 llms/mlx_lm/tokenizer_utils.py | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 50b87773..15b9963e 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -1,4 +1,5 @@
 import json
+from functools import partial
 
 from transformers import AutoTokenizer
 
@@ -114,7 +115,9 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
     underscore which results in linear complexity.
     """
 
-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer, trim_space=True):
+        self.trim_space = trim_space
+
         # Extract the tokens in a list from id to text
         self.tokenmap = [None] * len(tokenizer.vocab)
         for value, tokenid in tokenizer.vocab.items():
@@ -136,7 +139,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
     def add_token(self, token):
         v = self.tokenmap[token]
         if v[0] == "\u2581":
-            if self.text:
+            if self.text or not self.trim_space:
                 self.text += self._unflushed.replace("\u2581", " ")
             else:
                 self.text = _remove_space(self._unflushed.replace("\u2581", " "))
@@ -145,7 +148,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
             self._unflushed += v
 
     def finalize(self):
-        if self.text:
+        if self.text or not self.trim_space:
             self.text += self._unflushed.replace("\u2581", " ")
         else:
             self.text = _remove_space(self._unflushed.replace("\u2581", " "))
@@ -276,6 +279,18 @@ def _is_spm_decoder(decoder):
     return _match(_target_description, decoder)
 
 
+def _is_spm_decoder_no_space(decoder):
+    _target_description = {
+        "type": "Sequence",
+        "decoders": [
+            {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
+            {"type": "ByteFallback"},
+            {"type": "Fuse"},
+        ],
+    }
+    return _match(_target_description, decoder)
+
+
 def _is_bpe_decoder(decoder):
     _target_description = {
         "type": "ByteLevel",
@@ -302,6 +317,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
     if "decoder" in tokenizer_content:
         if _is_spm_decoder(tokenizer_content["decoder"]):
             detokenizer_class = SPMStreamingDetokenizer
+        elif _is_spm_decoder_no_space(tokenizer_content["decoder"]):
+            detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False)
         elif _is_bpe_decoder(tokenizer_content["decoder"]):
             detokenizer_class = BPEStreamingDetokenizer
 