Fix the cache_prompt (#979)

Angelos Katharopoulos 2024-09-06 20:19:27 -07:00 committed by GitHub
parent bd29aec299
commit 324184d670


@@ -139,8 +139,8 @@ def main():
     print("Saving...")
     cache_dict = {}
     for i, c in enumerate(cache):
-        cache_dict[f"{i}_keys"] = c.state[0]
-        cache_dict[f"{i}_values"] = c.state[1]
+        cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
+        cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
     metadata = {}
     metadata["model"] = args.model
     metadata["chat_template"] = tokenizer.chat_template