mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-02 13:40:48 +08:00
Generation refactor: part 2 (#1099)
* unify with stream_generate
* fixes
* nit
* some cleanup, warnings, tests
* fix test + faster min p + test
* version
This commit is contained in:
@@ -73,16 +73,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
|
||||
|
||||
def reset(self):
|
||||
self.offset = 0
|
||||
self._tokens = []
|
||||
self.tokens = []
|
||||
self._text = ""
|
||||
self._current_tokens = []
|
||||
self._current_text = ""
|
||||
|
||||
def add_token(self, token):
|
||||
self._current_tokens.append(token)
|
||||
self.tokens.append(token)
|
||||
|
||||
def finalize(self):
|
||||
self._tokens.extend(self._current_tokens)
|
||||
self._text += self._tokenizer.decode(self._current_tokens)
|
||||
self._current_tokens = []
|
||||
self._current_text = ""
|
||||
@@ -97,16 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
|
||||
):
|
||||
self._current_text = self._current_text[:-1]
|
||||
if self._current_text and self._current_text[-1] == "\n":
|
||||
self._tokens.extend(self._current_tokens)
|
||||
self._text += self._current_text
|
||||
self._current_tokens.clear()
|
||||
self._current_text = ""
|
||||
return self._text + self._current_text
|
||||
|
||||
@property
|
||||
def tokens(self):
|
||||
return self._tokens
|
||||
|
||||
|
||||
class SPMStreamingDetokenizer(StreamingDetokenizer):
|
||||
"""A streaming detokenizer for SPM models.
|
||||
@@ -143,6 +138,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
|
||||
self.text += text
|
||||
|
||||
def add_token(self, token):
|
||||
self.tokens.append(token)
|
||||
v = self.tokenmap[token]
|
||||
if v.startswith(self._sep):
|
||||
self._flush()
|
||||
@@ -200,6 +196,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
||||
return current_text
|
||||
|
||||
def add_token(self, token):
|
||||
self.tokens.append(token)
|
||||
v = self.tokenmap[token]
|
||||
is_added = token in self._added_ids
|
||||
if is_added or self._byte_decoder[v[0]] == 32:
|
||||
|
Reference in New Issue
Block a user