diff --git a/llms/speculative_decoding/decoder.py b/llms/speculative_decoding/decoder.py
index 07e0d25f..9762579c 100644
--- a/llms/speculative_decoding/decoder.py
+++ b/llms/speculative_decoding/decoder.py
@@ -173,14 +173,38 @@ class SpeculativeDecoder:
             n_steps += 1
 
+            # Track whether this step was cut short by EOS or the token limit.
+            truncated = False
             for t in new_tokens.tolist():
                 if t == self.tokenizer.eos_id or n_generated >= max_tokens:
+                    truncated = True
                     break
                 outputs.append(t)
                 n_generated += 1
 
             str_output = self.tokenizer.decode(outputs)
-            print(str_output[skip:], end="", flush=True)
+            if self.color and not truncated:
+                # Print accepted draft text in blue; the last token came from
+                # the main model, so print it in the default color.
+                model_token = len(self.tokenizer.decode(outputs[-1]))
+                print(
+                    "\033[34m"
+                    + str_output[skip:-model_token]
+                    + "\033[0m",
+                    end="",
+                )
+                print(str_output[-model_token:], end="", flush=True)
+            elif self.color and truncated:
+                # The step ended early, so all new text came from the draft.
+                print(
+                    "\033[34m"
+                    + str_output[skip:]
+                    + "\033[0m",
+                    end="",
+                    flush=True,
+                )
+            else:
+                print(str_output[skip:], end="", flush=True)
             skip = len(str_output)
 
             if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
@@ -298,14 +322,38 @@ class PromptLookupDecoder:
             n_steps += 1
 
+            # Track whether this step was cut short by EOS or the token limit.
+            truncated = False
             for t in new_tokens.tolist():
                 if t == self.tokenizer.eos_id or n_generated >= max_tokens:
+                    truncated = True
                     break
                 outputs.append(t)
                 n_generated += 1
 
             str_output = self.tokenizer.decode(outputs)
-            print(str_output[skip:], end="", flush=True)
+            if self.color and not truncated:
+                # Print text matched from the prompt in blue; the last token
+                # came from the model, so print it in the default color.
+                model_token = len(self.tokenizer.decode(outputs[-1]))
+                print(
+                    "\033[34m"
+                    + str_output[skip:-model_token]
+                    + "\033[0m",
+                    end="",
+                )
+                print(str_output[-model_token:], end="", flush=True)
+            elif self.color and truncated:
+                # The step ended early, so all new text came from the lookup.
+                print(
+                    "\033[34m"
+                    + str_output[skip:]
+                    + "\033[0m",
+                    end="",
+                    flush=True,
+                )
+            else:
+                print(str_output[skip:], end="", flush=True)
             skip = len(str_output)
 
             if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
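
For context, the coloring scheme in the diff can be exercised outside the decoder. Below is a minimal, self-contained sketch of the same logic; `StubTokenizer`, `print_step`, and the simulated steps are illustrative stand-ins, not part of the PR or the mlx-examples API:

    # Minimal sketch of the streaming color scheme used in the diff above.
    BLUE = "\033[34m"
    RESET = "\033[0m"


    class StubTokenizer:
        """Stand-in for the real tokenizer; decode joins token strings."""

        def decode(self, tokens):
            if isinstance(tokens, list):
                return "".join(tokens)
            return tokens


    def print_step(tokenizer, outputs, skip, truncated, color=True):
        """Print the newly decoded text for one decoding step.

        Draft-accepted text is shown in blue; on a full (non-truncated)
        step the last token came from the main model and is shown in the
        default color. Returns the new `skip` offset.
        """
        str_output = tokenizer.decode(outputs)
        if color and not truncated:
            # Assumes the last token decodes to at least one character,
            # as the diff does.
            n = len(tokenizer.decode(outputs[-1]))
            print(BLUE + str_output[skip:-n] + RESET, end="")
            print(str_output[-n:], end="", flush=True)
        elif color and truncated:
            print(BLUE + str_output[skip:] + RESET, end="", flush=True)
        else:
            print(str_output[skip:], end="", flush=True)
        return len(str_output)


    if __name__ == "__main__":
        tok = StubTokenizer()
        outputs, skip = [], 0
        # Two simulated steps: a full step, then a truncated one.
        outputs += ["The", " quick", " brown"]
        skip = print_step(tok, outputs, skip, truncated=False)
        outputs += [" fox"]
        skip = print_step(tok, outputs, skip, truncated=True)
        print()

One design note: the sketch resets with `\033[0m` (default attributes) rather than `\033[30m` (black foreground), so the output stays readable on dark-background terminals.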