Add an SPM detokenizer that doesn't trim initial space (#681)

Angelos Katharopoulos 2024-04-15 14:15:25 -07:00 committed by GitHub
parent d3f8e4aee9
commit e55a9e8cb4


@@ -1,4 +1,5 @@
 import json
+from functools import partial
 
 from transformers import AutoTokenizer
 
@@ -114,7 +115,9 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
     underscore which results in linear complexity.
     """
 
-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer, trim_space=True):
+        self.trim_space = trim_space
+
         # Extract the tokens in a list from id to text
         self.tokenmap = [None] * len(tokenizer.vocab)
         for value, tokenid in tokenizer.vocab.items():
@@ -136,7 +139,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
     def add_token(self, token):
         v = self.tokenmap[token]
         if v[0] == "\u2581":
-            if self.text:
+            if self.text or not self.trim_space:
                 self.text += self._unflushed.replace("\u2581", " ")
             else:
                 self.text = _remove_space(self._unflushed.replace("\u2581", " "))
@@ -145,7 +148,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
             self._unflushed += v
 
     def finalize(self):
-        if self.text:
+        if self.text or not self.trim_space:
             self.text += self._unflushed.replace("\u2581", " ")
         else:
             self.text = _remove_space(self._unflushed.replace("\u2581", " "))
@@ -276,6 +279,18 @@ def _is_spm_decoder(decoder):
     return _match(_target_description, decoder)
 
 
+def _is_spm_decoder_no_space(decoder):
+    _target_description = {
+        "type": "Sequence",
+        "decoders": [
+            {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
+            {"type": "ByteFallback"},
+            {"type": "Fuse"},
+        ],
+    }
+    return _match(_target_description, decoder)
+
+
 def _is_bpe_decoder(decoder):
     _target_description = {
         "type": "ByteLevel",
@@ -302,6 +317,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
         if "decoder" in tokenizer_content:
             if _is_spm_decoder(tokenizer_content["decoder"]):
                 detokenizer_class = SPMStreamingDetokenizer
+            elif _is_spm_decoder_no_space(tokenizer_content["decoder"]):
+                detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False)
             elif _is_bpe_decoder(tokenizer_content["decoder"]):
                 detokenizer_class = BPEStreamingDetokenizer