Fix the cache_prompt (#979)

Author: Angelos Katharopoulos
Date:   2024-09-06 20:19:27 -07:00
Committed by: GitHub
parent bd29aec299
commit 324184d670


@@ -139,8 +139,8 @@ def main():
print("Saving...")
cache_dict = {}
for i, c in enumerate(cache):
cache_dict[f"{i}_keys"] = c.state[0]
cache_dict[f"{i}_values"] = c.state[1]
cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template