mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
draft token coloring
This commit is contained in:
@@ -173,14 +173,36 @@ class SpeculativeDecoder:
|
|||||||
|
|
||||||
n_steps += 1
|
n_steps += 1
|
||||||
|
|
||||||
|
truncated = False
|
||||||
for t in new_tokens.tolist():
|
for t in new_tokens.tolist():
|
||||||
if t == self.tokenizer.eos_id or n_generated >= max_tokens:
|
if t == self.tokenizer.eos_id or n_generated >= max_tokens:
|
||||||
|
truncated = True
|
||||||
break
|
break
|
||||||
outputs.append(t)
|
outputs.append(t)
|
||||||
n_generated += 1
|
n_generated += 1
|
||||||
|
|
||||||
str_output = self.tokenizer.decode(outputs)
|
str_output = self.tokenizer.decode(outputs)
|
||||||
print(str_output[skip:], end="", flush=True)
|
self.color = True
|
||||||
|
if self.color and not truncated:
|
||||||
|
model_token = len(self.tokenizer.decode(outputs[-1]))
|
||||||
|
print(
|
||||||
|
"\033[34m"
|
||||||
|
+ str_output[skip:-model_token]
|
||||||
|
+ "\033[30m",
|
||||||
|
end="",
|
||||||
|
)
|
||||||
|
print(str_output[-model_token:], end="", flush=True)
|
||||||
|
elif self.color and truncated:
|
||||||
|
if truncated:
|
||||||
|
print(
|
||||||
|
"\033[34m"
|
||||||
|
+ str_output[skip:]
|
||||||
|
+ "\033[30m",
|
||||||
|
end="",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(str_output[skip:], end="", flush=True)
|
||||||
|
#print(str_output[skip:], end="", flush=True)
|
||||||
skip = len(str_output)
|
skip = len(str_output)
|
||||||
|
|
||||||
if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
|
if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
|
||||||
@@ -298,14 +320,37 @@ class PromptLookupDecoder:
|
|||||||
|
|
||||||
n_steps += 1
|
n_steps += 1
|
||||||
|
|
||||||
|
truncated = False
|
||||||
for t in new_tokens.tolist():
|
for t in new_tokens.tolist():
|
||||||
if t == self.tokenizer.eos_id or n_generated >= max_tokens:
|
if t == self.tokenizer.eos_id or n_generated >= max_tokens:
|
||||||
|
truncated = True
|
||||||
break
|
break
|
||||||
outputs.append(t)
|
outputs.append(t)
|
||||||
n_generated += 1
|
n_generated += 1
|
||||||
|
|
||||||
str_output = self.tokenizer.decode(outputs)
|
str_output = self.tokenizer.decode(outputs)
|
||||||
print(str_output[skip:], end="", flush=True)
|
#print(str_output[skip:], end="", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
if self.color and not truncated:
|
||||||
|
model_token = len(self.tokenizer.decode(outputs[-1]))
|
||||||
|
print(
|
||||||
|
"\033[34m"
|
||||||
|
+ str_output[skip:-model_token]
|
||||||
|
+ "\033[30m",
|
||||||
|
end="",
|
||||||
|
)
|
||||||
|
print(str_output[-model_token:], end="", flush=True)
|
||||||
|
elif self.color and truncated:
|
||||||
|
if truncated:
|
||||||
|
print(
|
||||||
|
"\033[34m"
|
||||||
|
+ str_output[skip:]
|
||||||
|
+ "\033[30m",
|
||||||
|
end="",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(str_output[skip:], end="", flush=True)
|
||||||
skip = len(str_output)
|
skip = len(str_output)
|
||||||
|
|
||||||
if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
|
if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
|
||||||
|
Reference in New Issue
Block a user