draft token coloring

This commit is contained in:
Leon Ericsson
2024-01-07 11:58:37 +01:00
parent 37c5555616
commit 692dd6c27c

View File

@@ -173,14 +173,36 @@ class SpeculativeDecoder:
n_steps += 1
truncated = False
for t in new_tokens.tolist():
if t == self.tokenizer.eos_id or n_generated >= max_tokens:
truncated = True
break
outputs.append(t)
n_generated += 1
str_output = self.tokenizer.decode(outputs)
print(str_output[skip:], end="", flush=True)
self.color = True
if self.color and not truncated:
model_token = len(self.tokenizer.decode(outputs[-1]))
print(
"\033[34m"
+ str_output[skip:-model_token]
+ "\033[30m",
end="",
)
print(str_output[-model_token:], end="", flush=True)
elif self.color and truncated:
if truncated:
print(
"\033[34m"
+ str_output[skip:]
+ "\033[30m",
end="",
)
else:
print(str_output[skip:], end="", flush=True)
#print(str_output[skip:], end="", flush=True)
skip = len(str_output)
if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
@@ -298,14 +320,37 @@ class PromptLookupDecoder:
n_steps += 1
truncated = False
for t in new_tokens.tolist():
if t == self.tokenizer.eos_id or n_generated >= max_tokens:
truncated = True
break
outputs.append(t)
n_generated += 1
str_output = self.tokenizer.decode(outputs)
print(str_output[skip:], end="", flush=True)
#print(str_output[skip:], end="", flush=True)
if self.color and not truncated:
model_token = len(self.tokenizer.decode(outputs[-1]))
print(
"\033[34m"
+ str_output[skip:-model_token]
+ "\033[30m",
end="",
)
print(str_output[-model_token:], end="", flush=True)
elif self.color and truncated:
if truncated:
print(
"\033[34m"
+ str_output[skip:]
+ "\033[30m",
end="",
)
else:
print(str_output[skip:], end="", flush=True)
skip = len(str_output)
if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id: