diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 10a257f6..114a35e7 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -3,8 +3,6 @@ from functools import partial
 
 from transformers import AutoTokenizer
 
-REPLACEMENT_CHAR = "\ufffd"
-
 
 class StreamingDetokenizer:
     """The streaming detokenizer interface so that we can detokenize one token at a time.
@@ -51,11 +49,9 @@ class StreamingDetokenizer:
     def last_segment(self):
         """Return the last segment of readable text since last time this property was accessed."""
         text = self.text
-        if text and text[-1] != REPLACEMENT_CHAR:
-            segment = text[self.offset :]
-            self.offset = len(text)
-            return segment
-        return ""
+        segment = text[self.offset :]
+        self.offset = len(text)
+        return segment
 
 
 class NaiveStreamingDetokenizer(StreamingDetokenizer):
@@ -132,7 +128,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
         self.tokens = []
 
     def _flush(self):
-        text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
+        text = self._unflushed.replace(self._sep, b" ").decode("utf-8", "replace")
         if not self.text and self.trim_space and text and text[0] == " ":
             text = text[1:]
         self.text += text
@@ -202,7 +198,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
             if is_added or self._byte_decoder[v[0]] == 32:
                 current_text = bytearray(
                     self._byte_decoder[c] for c in self._unflushed
-                ).decode("utf-8")
+                ).decode("utf-8", "replace")
                 self.text += self._maybe_trim_space(current_text)
                 if is_added:
                     self.text += v
@@ -214,7 +210,8 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
 
     def finalize(self):
         current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
-            "utf-8"
+            "utf-8",
+            "replace",
         )
         self.text += self._maybe_trim_space(current_text)
         self._unflushed = ""
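
Reviewer note, not part of the patch: the change moves malformed-byte handling out of
`last_segment` (which previously withheld any segment ending in the replacement
character U+FFFD) and into the `decode` calls via errors="replace". A minimal,
standalone sketch of what "replace" does when a multi-byte character arrives split
across chunks (the `chunks` data below is hypothetical, for illustration only):

    # "café" arriving in pieces; the 2-byte UTF-8 sequence for "é"
    # (0xC3 0xA9) is split across chunks.
    chunks = [b"caf", b"\xc3", b"\xa9"]

    buffered = b""
    for chunk in chunks:
        buffered += chunk
        # Strict decoding raises while the 0xC3 byte is still dangling.
        try:
            strict = buffered.decode("utf-8")
        except UnicodeDecodeError:
            strict = "<UnicodeDecodeError>"
        # errors="replace" never raises; the dangling byte shows up as U+FFFD.
        lossy = buffered.decode("utf-8", "replace")
        print(strict, "|", lossy)

In the patch itself the SPM and BPE detokenizers only decode `_unflushed` at word
boundaries or in `finalize`, so valid multi-byte characters are normally complete by
the time they are decoded; "replace" just guarantees that genuinely malformed bytes
yield U+FFFD instead of raising mid-stream.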