Added more fixes

This commit is contained in:
paramthakkar123
2025-04-07 08:59:07 +05:30
parent 298178d669
commit 4304f5aaf5
4 changed files with 24 additions and 14 deletions

View File

@@ -160,16 +160,16 @@ class SpeculativeDecoder:
)
n_accepted += num_to_accept
n_draft += draft_tokens.size
n_draft += len(draft_tokens)
# Rewind the cache for unaccepted tokens:
if (n := draft_tokens.size) > num_to_accept:
self.draft_model.truncate_cache(n - new_tokens.size)
self.model.truncate_cache(n - new_tokens.size + 1)
if (n := len(draft_tokens)) > num_to_accept:
self.draft_model.truncate_cache(n - len(new_tokens))
self.model.truncate_cache(n - len(new_tokens) + 1)
n_steps += 1
for t in new_tokens.tolist():
for t in list(new_tokens):
if t == self.tokenizer.eos_id or ntoks >= max_tokens:
break
outputs.append(t)
@@ -181,7 +181,7 @@ class SpeculativeDecoder:
if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
break
draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
draft_inputs = new_tokens[max(len(new_tokens) - 2, 0) :]
inputs = draft_inputs[-1:]
print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)