Add an SPM detokenizer that doesn't trim initial space (#681)

This commit is contained in:
Angelos Katharopoulos
2024-04-15 14:15:25 -07:00
committed by GitHub
parent d3f8e4aee9
commit e55a9e8cb4

View File

@@ -1,4 +1,5 @@
import json
from functools import partial
from transformers import AutoTokenizer
@@ -114,7 +115,9 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
underscore which results in linear complexity.
"""
def __init__(self, tokenizer):
def __init__(self, tokenizer, trim_space=True):
self.trim_space = trim_space
# Extract the tokens in a list from id to text
self.tokenmap = [None] * len(tokenizer.vocab)
for value, tokenid in tokenizer.vocab.items():
@@ -136,7 +139,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
def add_token(self, token):
v = self.tokenmap[token]
if v[0] == "\u2581":
if self.text:
if self.text or not self.trim_space:
self.text += self._unflushed.replace("\u2581", " ")
else:
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
@@ -145,7 +148,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
self._unflushed += v
def finalize(self):
if self.text:
if self.text or not self.trim_space:
self.text += self._unflushed.replace("\u2581", " ")
else:
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
@@ -276,6 +279,18 @@ def _is_spm_decoder(decoder):
return _match(_target_description, decoder)
def _is_spm_decoder_no_space(decoder):
_target_description = {
"type": "Sequence",
"decoders": [
{"type": "Replace", "pattern": {"String": ""}, "content": " "},
{"type": "ByteFallback"},
{"type": "Fuse"},
],
}
return _match(_target_description, decoder)
def _is_bpe_decoder(decoder):
_target_description = {
"type": "ByteLevel",
@@ -302,6 +317,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
if "decoder" in tokenizer_content:
if _is_spm_decoder(tokenizer_content["decoder"]):
detokenizer_class = SPMStreamingDetokenizer
elif _is_spm_decoder_no_space(tokenizer_content["decoder"]):
detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False)
elif _is_bpe_decoder(tokenizer_content["decoder"]):
detokenizer_class = BPEStreamingDetokenizer