From 845efddc8cb4578fe008c2ad0c26ec595e7f6b1e Mon Sep 17 00:00:00 2001
From: Billel Mokeddem
Date: Tue, 17 Dec 2024 21:54:29 +0400
Subject: [PATCH] Fix decoding manually added tokens (#1164)

* Fix decoding manually added tokens

* fix + test

* nit

* nit

* no lag bpe

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/tokenizer_utils.py | 44 +++++++++++++++++++---------------
 llms/tests/test_tokenizers.py  |  4 ++++
 2 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 8251e62f..ca3d6c06 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -127,23 +127,23 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
         self.text = ""
         self.tokens = []
 
-    def _flush(self):
+    def _try_flush(self, force=False):
         text = self._unflushed.replace(self._sep, b" ").decode("utf-8", "replace")
+        if not force and text.endswith("\ufffd"):
+            return
         if not self.text and self.trim_space and text and text[0] == " ":
             text = text[1:]
         self.text += text
+        self._unflushed = b""
 
     def add_token(self, token):
         self.tokens.append(token)
         v = self.tokenmap[token]
-        if v.startswith(self._sep):
-            self._flush()
-            self._unflushed = v
-        else:
-            self._unflushed += v
+        self._unflushed += v
+        self._try_flush()
 
     def finalize(self):
-        self._flush()
+        self._try_flush(force=True)
         self._unflushed = b""
 
@@ -158,7 +158,6 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
     _space_matches = (".", "?", "!", ",", "n't", "'m", "'s", "'ve", "'re")
 
     def __init__(self, tokenizer):
-
         self.clean_spaces = tokenizer.clean_up_tokenization_spaces
 
         # Extract the tokens in a list from id to text
@@ -172,14 +171,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
         # https://github.com/openai/gpt-2/blob/master/src/encoder.py
         self.make_byte_decoder()
 
-        self._added_ids = set(tokenizer.added_tokens_decoder.keys())
-
     def reset(self):
         self.offset = 0
         self._unflushed = ""
         self.text = ""
         self.tokens = []
 
+    def _decode_bytes(self, seq):
+        barr = bytearray()
+        for c in seq:
+            res = self._byte_decoder.get(c)
+            if res is not None:
+                barr.append(res)
+            else:
+                barr.extend(bytes(c, "utf-8"))
+        return barr.decode("utf-8", "replace")
+
     def _maybe_trim_space(self, current_text):
         if len(current_text) == 0:
             return current_text
@@ -194,15 +201,14 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
     def add_token(self, token):
         self.tokens.append(token)
         v = self.tokenmap[token]
-        is_added = token in self._added_ids
-        if not is_added:
-            self._unflushed += v
-        text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
-            "utf-8", "replace"
-        )
-        if is_added:
-            text += v
-        if not text.endswith("\ufffd"):
+        self._unflushed += v
+        text = self._decode_bytes(self._unflushed)
+
+        # For multi-byte utf-8 sequences, wait until they are complete
+        # For a single space, wait for the next token in case it needs cleaning up
+        if not text.endswith("\ufffd") and not (
+            len(v) == 1 and self._byte_decoder.get(v[0]) == 32
+        ):
             self.text += self._maybe_trim_space(text)
             self._unflushed = ""
 
diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py
index db6b9f9e..3009d1b1 100644
--- a/llms/tests/test_tokenizers.py
+++ b/llms/tests/test_tokenizers.py
@@ -58,6 +58,9 @@ class TestTokenizers(unittest.TestCase):
         tokens = tokenizer.encode("import 'package:flutter/material.dart';")
         check(tokens)
 
+        tokens = tokenizer.encode("hello\nworld")
+        check(tokens)
+
     def test_tokenizers(self):
         tokenizer_repos = [
             ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
@@ -65,6 +68,7 @@ class TestTokenizers(unittest.TestCase):
             ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
             ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
             ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
+            ("mlx-community/Falcon3-7B-Instruct-4bit", BPEStreamingDetokenizer),
         ]
         for tokenizer_repo, expected_detokenizer in tokenizer_repos:
             with self.subTest(tokenizer=tokenizer_repo):
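
Note: the sketch below is not part of the patch. It illustrates, under stated
assumptions, the decoding behavior the new `_decode_bytes` helper implements.
The GPT-2 byte<->unicode mapping (which `make_byte_decoder` follows, per the
link in the diff) covers only the 256 symbols of the byte-level alphabet, so
the verbatim text of a manually added token (e.g. a literal "\n") is not in
the map and previously raised a KeyError; the fix falls back to the
character's own UTF-8 bytes. The names `bytes_to_unicode`, `byte_decoder`,
and `decode_bytes` are illustrative, not part of the mlx_lm API.

    # Illustrative, self-contained sketch (not mlx_lm code): GPT-2 style
    # byte-level decoding with the added-token fallback from this patch.
    def bytes_to_unicode():
        # GPT-2's reversible byte -> unicode mapping: printable latin-1
        # bytes map to themselves, every other byte maps to chr(256 + n).
        bs = (
            list(range(ord("!"), ord("~") + 1))
            + list(range(ord("¡"), ord("¬") + 1))
            + list(range(ord("®"), ord("ÿ") + 1))
        )
        cs = bs[:]
        n = 0
        for b in range(256):
            if b not in bs:
                bs.append(b)
                cs.append(256 + n)
                n += 1
        return dict(zip(bs, map(chr, cs)))

    byte_decoder = {c: b for b, c in bytes_to_unicode().items()}

    def decode_bytes(seq):
        # Mirrors the new `_decode_bytes`: characters outside the byte-level
        # alphabet (e.g. raw added-token text) fall back to their own UTF-8
        # bytes instead of raising a KeyError.
        barr = bytearray()
        for c in seq:
            b = byte_decoder.get(c)
            if b is not None:
                barr.append(b)
            else:
                barr.extend(c.encode("utf-8"))
        return barr.decode("utf-8", "replace")

    print(decode_bytes("hello"))           # ordinary token text -> 'hello'
    print(decode_bytes("Ġworld"))          # 'Ġ' encodes byte 0x20 -> ' world'
    print(decode_bytes("\n") == "\n")      # added-token text, absent from
                                           # the map -> True
    print(decode_bytes("ä") == "\ufffd")   # lone lead byte of a multi-byte
                                           # UTF-8 sequence -> True

The last case is why `add_token` only flushes once the decoded text no longer
ends in "\ufffd", and presumably why the test now round-trips "hello\nworld"
and adds Falcon3 (whose tokenizer registers such added tokens) to the matrix.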