This commit is contained in:
Leon Ericsson 2024-01-01 12:53:11 +01:00
parent 3f750759d3
commit 7da89adef8

View File

@ -32,7 +32,7 @@ class PromptLookupDecoder:
x = sample(logits)
yield x
# Generate without prompt lookup decoding (for testing)
# Normal decoding w/o prompt lookup (for testing)
def generate(
self,
prompt,
@ -140,10 +140,11 @@ class PromptLookupDecoder:
n_decoding_steps += 1
# Check stop decodig criteria:
# Check stop decodig criteria and print accepted draft tokens.
for t in new_tokens.tolist()[:-1]:
if t == self.tokenizer.eos_id:
break
if color:
print(
"\033[34m" + self.tokenizer.decode([t]) + "\033[30m",