Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 01:17:28 +08:00
Add an SPM detokenizer that doesn't trim initial space (#681)

parent d3f8e4aee9
commit e55a9e8cb4
@@ -1,4 +1,5 @@
 import json
+from functools import partial
 
 from transformers import AutoTokenizer
 
@@ -114,7 +115,9 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
     underscore which results in linear complexity.
     """
 
-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer, trim_space=True):
+        self.trim_space = trim_space
+
         # Extract the tokens in a list from id to text
         self.tokenmap = [None] * len(tokenizer.vocab)
         for value, tokenid in tokenizer.vocab.items():
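The new trim_space flag defaults to True, so existing callers keep the old behavior (the leading space produced by the ▁ marker is stripped). Opting out is explicit; a minimal sketch, assuming tokenizer is any SentencePiece-style tokenizer such as one returned by AutoTokenizer.from_pretrained (an end-to-end trace follows the finalize hunk below):

    # Default preserves the historical behavior: leading space trimmed.
    detok_default = SPMStreamingDetokenizer(tokenizer)
    # New opt-in behavior: the initial space survives detokenization.
    detok_no_trim = SPMStreamingDetokenizer(tokenizer, trim_space=False)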
@@ -136,7 +139,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
     def add_token(self, token):
         v = self.tokenmap[token]
         if v[0] == "\u2581":
-            if self.text:
+            if self.text or not self.trim_space:
                 self.text += self._unflushed.replace("\u2581", " ")
             else:
                 self.text = _remove_space(self._unflushed.replace("\u2581", " "))
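SentencePiece marks word-initial pieces with U+2581 (▁), so seeing a ▁ token means the buffered pieces form a complete word and can be flushed with ▁ mapped back to a space. The same lines, annotated:

    if v[0] == "\u2581":  # ▁ starts a new word: flush the buffer
        if self.text or not self.trim_space:
            # Mid-stream flush, or no-trim mode: keep the space that ▁ maps to.
            self.text += self._unflushed.replace("\u2581", " ")
        else:
            # Very first flush in trim mode: drop the leading space.
            self.text = _remove_space(self._unflushed.replace("\u2581", " "))

Before this change, the first flush always took the _remove_space branch, which is why streaming output lost its initial space even when the model produced one.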
@@ -145,7 +148,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
             self._unflushed += v
 
     def finalize(self):
-        if self.text:
+        if self.text or not self.trim_space:
             self.text += self._unflushed.replace("\u2581", " ")
         else:
             self.text = _remove_space(self._unflushed.replace("\u2581", " "))
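finalize applies the same guard when flushing whatever is still buffered at the end of the stream. To see both modes end to end, here is a minimal sketch; the _StubTokenizer and its three-token vocabulary are invented for illustration, and the import assumes the mlx-lm package layout:

    from mlx_lm.tokenizer_utils import SPMStreamingDetokenizer

    class _StubTokenizer:
        # SPM-style vocabulary: "\u2581" (▁) marks word-initial pieces.
        vocab = {"\u2581Hello": 0, "\u2581world": 1, "!": 2}

    def detokenize(trim_space):
        detok = SPMStreamingDetokenizer(_StubTokenizer(), trim_space=trim_space)
        for token in (0, 1, 2):
            detok.add_token(token)
        detok.finalize()
        return detok.text

    print(repr(detokenize(True)))   # 'Hello world!'  -- leading space trimmed
    print(repr(detokenize(False)))  # ' Hello world!' -- leading space kept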
@@ -276,6 +279,18 @@ def _is_spm_decoder(decoder):
     return _match(_target_description, decoder)
 
 
+def _is_spm_decoder_no_space(decoder):
+    _target_description = {
+        "type": "Sequence",
+        "decoders": [
+            {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
+            {"type": "ByteFallback"},
+            {"type": "Fuse"},
+        ],
+    }
+    return _match(_target_description, decoder)
+
+
 def _is_bpe_decoder(decoder):
     _target_description = {
         "type": "ByteLevel",
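The new matcher mirrors _is_spm_decoder above, minus the trailing Strip step that removes the leading space; a tokenizer.json whose decoder omits that step is asking for the space to be kept. A sketch of such a decoder section, written as a Python dict (assuming _match performs a structural comparison of this shape):

    _no_space_decoder = {
        "type": "Sequence",
        "decoders": [
            {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
            {"type": "ByteFallback"},
            {"type": "Fuse"},
            # No {"type": "Strip", ...} entry here: the initial space stays.
        ],
    }
    assert _is_spm_decoder_no_space(_no_space_decoder)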
@@ -302,6 +317,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
     if "decoder" in tokenizer_content:
         if _is_spm_decoder(tokenizer_content["decoder"]):
             detokenizer_class = SPMStreamingDetokenizer
+        elif _is_spm_decoder_no_space(tokenizer_content["decoder"]):
+            detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False)
         elif _is_bpe_decoder(tokenizer_content["decoder"]):
             detokenizer_class = BPEStreamingDetokenizer
 
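functools.partial (hence the new import at the top of the file) bakes trim_space=False into the class, so every branch leaves detokenizer_class as something callable with the tokenizer alone. A sketch of the idea, reusing the _StubTokenizer from the earlier example:

    from functools import partial

    # partial(...) behaves like the class itself at the call site, so the
    # later construction code needs no special case for the no-trim variant.
    detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False)
    detok = detokenizer_class(_StubTokenizer())
    assert detok.trim_space is False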