Generation refactor: part 2 (#1099)

* unify with stream_generate

* fixes

* nit

* some cleanup, warnings, tests

* fix test + faster min p + test

* version
Author: Awni Hannun
Date: 2024-11-23 11:47:06 -08:00
Committed by: GitHub
Parent: 004eb4cc9d
Commit: 0f135396ae

13 changed files with 184 additions and 197 deletions
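
Note on the "faster min p" item in the commit message: min-p sampling keeps only tokens whose probability is at least min_p times that of the most likely token, and the filter can be computed without a full vocabulary sort. Below is a minimal sketch of that idea, assuming mlx arrays; the function name and sort-free approach are illustrative, not necessarily this commit's exact implementation.

import math

import mlx.core as mx


def min_p_filter(logprobs: mx.array, min_p: float = 0.1) -> mx.array:
    # Keep tokens with probability >= min_p * p(top token). In log space
    # that means logprob >= max_logprob + log(min_p); since softmax is
    # shift-invariant, raw logits work here just as well.
    top = mx.max(logprobs, axis=-1, keepdims=True)
    return mx.where(logprobs < top + math.log(min_p), -float("inf"), logprobs)


# Sampling from the filtered distribution:
# token = mx.random.categorical(min_p_filter(logprobs, min_p=0.05))

Skipping the sort is a common way to speed this step up; the commit's actual change may differ in detail.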


@@ -73,16 +73,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
     def reset(self):
         self.offset = 0
-        self._tokens = []
+        self.tokens = []
         self._text = ""
         self._current_tokens = []
         self._current_text = ""
 
     def add_token(self, token):
         self._current_tokens.append(token)
+        self.tokens.append(token)
 
     def finalize(self):
-        self._tokens.extend(self._current_tokens)
         self._text += self._tokenizer.decode(self._current_tokens)
         self._current_tokens = []
         self._current_text = ""
@@ -97,16 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
             ):
                 self._current_text = self._current_text[:-1]
         if self._current_text and self._current_text[-1] == "\n":
-            self._tokens.extend(self._current_tokens)
             self._text += self._current_text
             self._current_tokens.clear()
             self._current_text = ""
         return self._text + self._current_text
 
-    @property
-    def tokens(self):
-        return self._tokens
-
 
 class SPMStreamingDetokenizer(StreamingDetokenizer):
     """A streaming detokenizer for SPM models.
@@ -143,6 +138,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
         self.text += text
 
     def add_token(self, token):
+        self.tokens.append(token)
         v = self.tokenmap[token]
         if v.startswith(self._sep):
             self._flush()
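
For context on this hunk: SentencePiece marks word boundaries with "▁" (U+2581), so the SPM detokenizer flushes the pending word whenever an incoming piece starts with that separator; the new first line just records the raw id before that logic runs. A standalone sketch of the flush rule, using str pieces for simplicity (spm_segments is a hypothetical helper, not from the repo):

SEP = "\u2581"  # SentencePiece word-boundary marker


def spm_segments(pieces):
    # Yield readable text whenever a word completes.
    unflushed = ""
    for piece in pieces:
        if piece.startswith(SEP):  # a new word begins: flush the old one
            yield unflushed.replace(SEP, " ")
            unflushed = piece
        else:  # continuation of the current word
            unflushed += piece
    yield unflushed.replace(SEP, " ")  # final flush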
@@ -200,6 +196,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
         return current_text
 
     def add_token(self, token):
+        self.tokens.append(token)
         v = self.tokenmap[token]
         is_added = token in self._added_ids
         if is_added or self._byte_decoder[v[0]] == 32:
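
Taken together, the four hunks give all three detokenizers the same public surface, which is what lets the generation code (per the "unify with stream_generate" bullet) treat them interchangeably. A hedged usage sketch; last_segment follows the shape of this file's base-class API, and stream_text is a hypothetical caller:

def stream_text(detokenizer, token_ids):
    # Drain ids through any of the detokenizers, printing text as it
    # becomes decodable.
    detokenizer.reset()
    for token in token_ids:
        detokenizer.add_token(token)
        print(detokenizer.last_segment, end="", flush=True)
    detokenizer.finalize()
    print(detokenizer.last_segment, flush=True)
    # After this refactor, every subclass records the same history:
    assert detokenizer.tokens == list(token_ids)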