Add streaming detokenizers (#651)

Repository: ml-explore/mlx-examples (mirror of https://github.com/ml-explore/mlx-examples.git)
Commit: 1278994b56 (parent: c68aa3c7c3)

New file: llms/mlx_lm/tokenizer_utils.py (311 lines)
@@ -0,0 +1,311 @@
import json

from transformers import AutoTokenizer


REPLACEMENT_CHAR = "\ufffd"


def _remove_space(x):
    if x and x[0] == " ":
        return x[1:]
    return x


class StreamingDetokenizer:
    """The streaming detokenizer interface so that we can detokenize one token at a time.

    Example usage is as follows:

        detokenizer = ...

        # Reset the tokenizer state
        detokenizer.reset()

        for token in generate(...):
            detokenizer.add_token(token.item())

            # Contains the whole text so far. Some tokens may not be included
            # since it contains whole words usually.
            detokenizer.text

            # Contains the printable segment (usually a word) since the last
            # time it was accessed
            detokenizer.last_segment

            # Contains all the tokens added so far
            detokenizer.tokens

        # Make sure that we detokenize any remaining tokens
        detokenizer.finalize()

        # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)
    """

    __slots__ = ("text", "tokens", "offset")

    def reset(self):
        raise NotImplementedError()

    def add_token(self, token):
        raise NotImplementedError()

    def finalize(self):
        raise NotImplementedError()

    @property
    def last_segment(self):
        """Return the last segment of readable text since last time this property was accessed."""
        text = self.text
        if text and text[-1] != REPLACEMENT_CHAR:
            segment = text[self.offset :]
            self.offset = len(text)
            return segment
        return ""
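Aside (not part of the committed file): a minimal sketch of how `last_segment` and `offset` interact, using a toy subclass with a made-up three-token vocabulary.

class _ToyDetokenizer(StreamingDetokenizer):
    # Hypothetical id-to-string map used only for this illustration.
    _vocab = {0: "Hello", 1: " world", 2: "!"}

    def reset(self):
        self.offset = 0
        self.text = ""
        self.tokens = []

    def add_token(self, token):
        self.tokens.append(token)
        self.text += self._vocab[token]

    def finalize(self):
        pass


_toy = _ToyDetokenizer()
_toy.reset()
_toy.add_token(0)
print(_toy.last_segment)  # "Hello" -- offset advances to len("Hello")
_toy.add_token(1)
_toy.add_token(2)
print(_toy.last_segment)  # " world!" -- only the text added since the last access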
class NaiveStreamingDetokenizer(StreamingDetokenizer):
    """NaiveStreamingDetokenizer relies on the underlying tokenizer
    implementation and should work with every tokenizer.

    Its complexity is O(T^2) where T is the longest line since it will
    repeatedly detokenize the same tokens until a new line is generated.
    """

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        self.reset()

    def reset(self):
        self.offset = 0
        self._tokens = []
        self._text = ""
        self._current_tokens = []
        self._current_text = ""

    def add_token(self, token):
        self._current_tokens.append(token)

    def finalize(self):
        self._tokens.extend(self._current_tokens)
        self._text += self._tokenizer.decode(self._current_tokens)
        self._current_tokens = []
        self._current_text = ""

    @property
    def text(self):
        if self._current_tokens:
            self._current_text = self._tokenizer.decode(self._current_tokens)
        if self._current_text and self._current_text[-1] == "\n":
            self._tokens.extend(self._current_tokens)
            self._text += self._current_text
            self._current_tokens.clear()
            self._current_text = ""
        return self._text + self._current_text

    @property
    def tokens(self):
        return self._tokens


class SPMStreamingDetokenizer(StreamingDetokenizer):
    """A streaming detokenizer for SPM models.

    It adds tokens to the text if the next token starts with the special SPM
    underscore which results in linear complexity.
    """

    def __init__(self, tokenizer):
        # Extract the tokens in a list from id to text
        self.tokenmap = [None] * len(tokenizer.vocab)
        for value, tokenid in tokenizer.vocab.items():
            self.tokenmap[tokenid] = value

        # Replace bytes with their value
        for i in range(len(self.tokenmap)):
            if self.tokenmap[i].startswith("<0x"):
                self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))

        self.reset()

    def reset(self):
        self.offset = 0
        self._unflushed = ""
        self.text = ""
        self.tokens = []

    def add_token(self, token):
        v = self.tokenmap[token]
        if v[0] == "\u2581":
            if self.text:
                self.text += self._unflushed.replace("\u2581", " ")
            else:
                self.text = _remove_space(self._unflushed.replace("\u2581", " "))
            self._unflushed = v
        else:
            self._unflushed += v

    def finalize(self):
        if self.text:
            self.text += self._unflushed.replace("\u2581", " ")
        else:
            self.text = _remove_space(self._unflushed.replace("\u2581", " "))
        self._unflushed = ""
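Aside (not part of the committed file): a sketch of the flush-on-underscore behavior, using a made-up four-piece vocabulary; a real `tokenizer.vocab` maps every SPM piece to its id.

class _FakeSPMTokenizer:
    # Hypothetical piece-to-id map used only for this illustration.
    vocab = {"\u2581Hello": 0, "\u2581wor": 1, "ld": 2, "!": 3}


_spm = SPMStreamingDetokenizer(_FakeSPMTokenizer())
for _t in (0, 1, 2, 3):
    _spm.add_token(_t)
    print(repr(_spm.text))  # '', 'Hello', 'Hello', 'Hello' -- a piece is flushed only when the next "\u2581" piece arrives
_spm.finalize()
print(repr(_spm.text))  # 'Hello world!'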
class BPEStreamingDetokenizer(StreamingDetokenizer):
    """A streaming detokenizer for OpenAI style BPE models.

    It adds tokens to the text if the next token starts with a space similar to
    the SPM detokenizer.
    """

    _byte_decoder = None

    def __init__(self, tokenizer, trim_space=False):
        self.trim_space = trim_space

        # Extract the tokens in a list from id to text
        self.tokenmap = [None] * len(tokenizer.vocab)
        for value, tokenid in tokenizer.vocab.items():
            self.tokenmap[tokenid] = value

        self.reset()

        # Make the BPE byte decoder from
        # https://github.com/openai/gpt-2/blob/master/src/encoder.py
        self.make_byte_decoder()

    def reset(self):
        self.offset = 0
        self._unflushed = ""
        self.text = ""
        self.tokens = []

    def add_token(self, token):
        v = self.tokenmap[token]
        # if the token starts with space
        if self._byte_decoder[v[0]] == 32:
            current_text = bytearray(
                self._byte_decoder[c] for c in self._unflushed
            ).decode("utf-8")
            if self.text or not self.trim_space:
                self.text += current_text
            else:
                self.text += _remove_space(current_text)
            self._unflushed = v
        else:
            self._unflushed += v

    def finalize(self):
        current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
            "utf-8"
        )
        if self.text or not self.trim_space:
            self.text += current_text
        else:
            self.text += _remove_space(current_text)
        self._unflushed = ""

    @classmethod
    def make_byte_decoder(cls):
        """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
        if cls._byte_decoder is not None:
            return

        char_to_bytes = {}
        limits = [
            0,
            ord("!"),
            ord("~") + 1,
            ord("¡"),
            ord("¬") + 1,
            ord("®"),
            ord("ÿ") + 1,
        ]
        n = 0
        for i, (start, stop) in enumerate(zip(limits, limits[1:])):
            if i % 2 == 0:
                for b in range(start, stop):
                    char_to_bytes[chr(2**8 + n)] = b
                    n += 1
            else:
                for b in range(start, stop):
                    char_to_bytes[chr(b)] = b
        cls._byte_decoder = char_to_bytes
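Aside (not part of the committed file): the byte decoder inverts the GPT-2 byte-to-unicode encoding, so the space byte (32) is stored as "Ġ" in BPE vocabularies; that is what the `self._byte_decoder[v[0]] == 32` check in `add_token` detects.

BPEStreamingDetokenizer.make_byte_decoder()
assert BPEStreamingDetokenizer._byte_decoder["\u0120"] == 32  # "Ġ", the encoded space byte
assert BPEStreamingDetokenizer._byte_decoder["a"] == ord("a")  # printable ASCII maps to itself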
class TokenizerWrapper:
    """A wrapper that combines an HF tokenizer and a detokenizer.

    Accessing any attribute other than the ``detokenizer`` is forwarded to the
    huggingface tokenizer.
    """

    def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
        self._tokenizer = tokenizer
        self._detokenizer = detokenizer_class(tokenizer)

    def __getattr__(self, attr):
        if attr == "detokenizer":
            return self._detokenizer
        else:
            return getattr(self._tokenizer, attr)
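Aside (not part of the committed file): because of the `__getattr__` forwarding, existing call sites keep working unchanged; the model path below is a placeholder.

_wrapped = TokenizerWrapper(AutoTokenizer.from_pretrained("path/to/model"))
_ids = _wrapped.encode("hello")   # forwarded to the Hugging Face tokenizer
_eos = _wrapped.eos_token_id      # also forwarded
_detok = _wrapped.detokenizer     # the only attribute handled by the wrapper itself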
def _match(a, b):
    if type(a) != type(b):
        return False
    if isinstance(a, dict):
        return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a)
    if isinstance(a, list):
        return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b))

    return a == b


def _is_spm_decoder(decoder):
    _target_description = {
        "type": "Sequence",
        "decoders": [
            {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
            {"type": "ByteFallback"},
            {"type": "Fuse"},
            {"type": "Strip", "content": " ", "start": 1, "stop": 0},
        ],
    }
    return _match(_target_description, decoder)


def _is_bpe_decoder(decoder):
    _target_description = {
        "type": "ByteLevel",
        "add_prefix_space": False,
        "trim_offsets": False,
        "use_regex": False,
    }

    return _match(_target_description, decoder)


def load_tokenizer(model_path, tokenizer_config_extra={}):
    """Load a huggingface tokenizer and try to infer the type of streaming
    detokenizer to use.

    Note, to use a fast streaming tokenizer, pass a local file path rather than
    a Hugging Face repo ID.
    """
    detokenizer_class = NaiveStreamingDetokenizer

    tokenizer_file = model_path / "tokenizer.json"
    if tokenizer_file.exists():
        tokenizer_content = json.load(tokenizer_file.open())
        if "decoder" in tokenizer_content:
            if _is_spm_decoder(tokenizer_content["decoder"]):
                detokenizer_class = SPMStreamingDetokenizer
            elif _is_bpe_decoder(tokenizer_content["decoder"]):
                detokenizer_class = BPEStreamingDetokenizer

    return TokenizerWrapper(
        AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
        detokenizer_class,
    )
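Aside (not part of the commit): a sketch of how the decoder inference above plays out for a local model directory; the path is a placeholder, and `tokenizer.json` must exist for one of the fast detokenizers to be selected.

from pathlib import Path

_model_path = Path("path/to/model")
with open(_model_path / "tokenizer.json") as f:
    _decoder = json.load(f).get("decoder", {})

if _is_spm_decoder(_decoder):
    print("load_tokenizer would pick SPMStreamingDetokenizer")
elif _is_bpe_decoder(_decoder):
    print("load_tokenizer would pick BPEStreamingDetokenizer")
else:
    print("load_tokenizer would fall back to NaiveStreamingDetokenizer")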
Changed file: llms/mlx_lm/utils.py

@@ -17,9 +17,9 @@ from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
 
-from .sample_utils import top_p_sampling
-
 # Local imports
+from .sample_utils import top_p_sampling
+from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import apply_lora_layers
 from .tuner.utils import dequantize as dequantize_model
 
@@ -189,7 +189,7 @@ def generate_step(
 def generate(
     model: nn.Module,
-    tokenizer: PreTrainedTokenizer,
+    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: str,
     temp: float = 0.0,
     max_tokens: int = 100,
@@ -215,18 +215,18 @@ def generate(
         repetition_penalty (float, optional): The penalty factor for repeating tokens.
         repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
     """
+    if not isinstance(tokenizer, TokenizerWrapper):
+        tokenizer = TokenizerWrapper(tokenizer)
+
     if verbose:
         print("=" * 10)
         print("Prompt:", prompt)
 
     prompt_tokens = mx.array(tokenizer.encode(prompt))
+    detokenizer = tokenizer.detokenizer
 
     tic = time.perf_counter()
-    tokens = []
-    token_strings = []
-    skip = 0
-    REPLACEMENT_CHAR = "\ufffd"
+    detokenizer.reset()
 
     for (token, prob), n in zip(
         generate_step(
@@ -245,29 +245,21 @@ def generate(
             tic = time.perf_counter()
         if token == tokenizer.eos_token_id:
             break
-        tokens.append(token)
+        detokenizer.add_token(token)
 
         if verbose:
-            s = tokenizer.decode(tokens)
-            if not s:
-                continue
-            elif formatter:
-                formatter(s[skip:], prob.item())
-                skip = len(s)
-            elif s[-1] != REPLACEMENT_CHAR:
-                print(s[skip:], end="", flush=True)
-                skip = len(s)
-                # Reset token cache at line break
-                if s[-1] == "\n":
-                    tokens = []
-                    token_strings.append(s)
-                    skip = 0
+            if formatter:
+                # We have to finalize so that the prob corresponds to the last segment
+                detokenizer.finalize()
+                formatter(detokenizer.last_segment, prob.item())
+            else:
+                print(detokenizer.last_segment, end="", flush=True)
 
     token_count = n + 1
-    token_strings.append(tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, ""))
+    detokenizer.finalize()
 
     if verbose:
-        print(token_strings[-1][skip:], flush=True)
+        print(detokenizer.last_segment, flush=True)
         gen_time = time.perf_counter() - tic
         print("=" * 10)
         if token_count == 0:
@@ -278,7 +270,7 @@ def generate(
         print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
         print(f"Generation: {gen_tps:.3f} tokens-per-sec")
 
-    return "".join(token_strings)
+    return detokenizer.text
 
 
 def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
@@ -384,8 +376,8 @@ def load(
     if adapter_path is not None:
         model = apply_lora_layers(model, adapter_path)
         model.eval()
+    tokenizer = load_tokenizer(model_path, tokenizer_config)
 
-    tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
     return model, tokenizer
 
 
@@ -394,7 +386,7 @@ def fetch_from_hub(
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
     model = load_model(model_path, lazy)
     config = AutoConfig.from_pretrained(model_path)
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer = load_tokenizer(model_path)
 
     return model, config.to_dict(), tokenizer
 
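Aside (not part of the commit): with these changes, `generate` accepts either a bare Hugging Face tokenizer or the `TokenizerWrapper` that `load` now returns, and verbose output streams through the detokenizer segment by segment. A minimal usage sketch, assuming the usual `mlx_lm` entry points and a placeholder model path:

from mlx_lm import load, generate

model, tokenizer = load("path/or/repo-of-a-converted-model")  # tokenizer is a TokenizerWrapper
text = generate(model, tokenizer, prompt="hello", max_tokens=32, verbose=True)
print(text)  # same text the detokenizer accumulated while streaming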