mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
Add an SPM detokenizer that doesn't trim initial space (#681)
This commit is contained in:
parent
d3f8e4aee9
commit
e55a9e8cb4
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from functools import partial
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@ -114,7 +115,9 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
|
||||
underscore which results in linear complexity.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
def __init__(self, tokenizer, trim_space=True):
|
||||
self.trim_space = trim_space
|
||||
|
||||
# Extract the tokens in a list from id to text
|
||||
self.tokenmap = [None] * len(tokenizer.vocab)
|
||||
for value, tokenid in tokenizer.vocab.items():
|
||||
@ -136,7 +139,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
|
||||
def add_token(self, token):
|
||||
v = self.tokenmap[token]
|
||||
if v[0] == "\u2581":
|
||||
if self.text:
|
||||
if self.text or not self.trim_space:
|
||||
self.text += self._unflushed.replace("\u2581", " ")
|
||||
else:
|
||||
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
|
||||
@ -145,7 +148,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
|
||||
self._unflushed += v
|
||||
|
||||
def finalize(self):
|
||||
if self.text:
|
||||
if self.text or not self.trim_space:
|
||||
self.text += self._unflushed.replace("\u2581", " ")
|
||||
else:
|
||||
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
|
||||
@ -276,6 +279,18 @@ def _is_spm_decoder(decoder):
|
||||
return _match(_target_description, decoder)
|
||||
|
||||
|
||||
def _is_spm_decoder_no_space(decoder):
|
||||
_target_description = {
|
||||
"type": "Sequence",
|
||||
"decoders": [
|
||||
{"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
|
||||
{"type": "ByteFallback"},
|
||||
{"type": "Fuse"},
|
||||
],
|
||||
}
|
||||
return _match(_target_description, decoder)
|
||||
|
||||
|
||||
def _is_bpe_decoder(decoder):
|
||||
_target_description = {
|
||||
"type": "ByteLevel",
|
||||
@ -302,6 +317,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
|
||||
if "decoder" in tokenizer_content:
|
||||
if _is_spm_decoder(tokenizer_content["decoder"]):
|
||||
detokenizer_class = SPMStreamingDetokenizer
|
||||
elif _is_spm_decoder_no_space(tokenizer_content["decoder"]):
|
||||
detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False)
|
||||
elif _is_bpe_decoder(tokenizer_content["decoder"]):
|
||||
detokenizer_class = BPEStreamingDetokenizer
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user