From 324184d670ec11916a5e92314171d497b312eefe Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Fri, 6 Sep 2024 20:19:27 -0700
Subject: [PATCH] Fix the cache_prompt (#979)

---
 llms/mlx_lm/cache_prompt.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
index fe088118..9829efb4 100644
--- a/llms/mlx_lm/cache_prompt.py
+++ b/llms/mlx_lm/cache_prompt.py
@@ -139,8 +139,8 @@ def main():
     print("Saving...")
     cache_dict = {}
     for i, c in enumerate(cache):
-        cache_dict[f"{i}_keys"] = c.state[0]
-        cache_dict[f"{i}_values"] = c.state[1]
+        cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
+        cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
     metadata = {}
     metadata["model"] = args.model
     metadata["chat_template"] = tokenizer.chat_template
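
A minimal, self-contained sketch of why this slicing matters, assuming a KVCache-style object that preallocates its key/value buffers in fixed-size steps and records the number of valid token positions in `offset`. The `ToyKVCache` class below is hypothetical and uses NumPy arrays as stand-ins; it is not the mlx_lm implementation.

import numpy as np


class ToyKVCache:
    """Hypothetical stand-in for a step-wise preallocated KV cache."""

    step = 256  # buffer grows in chunks of this many token positions

    def __init__(self, n_heads=8, head_dim=64):
        self.keys = np.zeros((1, n_heads, self.step, head_dim))
        self.values = np.zeros((1, n_heads, self.step, head_dim))
        self.offset = 0  # number of token positions actually filled

    @property
    def state(self):
        return self.keys, self.values

    def update(self, k, v):
        n = k.shape[2]
        self.keys[..., self.offset : self.offset + n, :] = k
        self.values[..., self.offset : self.offset + n, :] = v
        self.offset += n


# Fill the cache with 100 tokens; the buffer itself still holds 256 positions.
cache = [ToyKVCache()]
cache[0].update(np.ones((1, 8, 100, 64)), np.ones((1, 8, 100, 64)))

cache_dict = {}
for i, c in enumerate(cache):
    # Saving c.state[0] directly would persist all 256 positions, 156 of them
    # uninitialized padding; slicing to the offset keeps only the 100 valid ones.
    cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
    cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]

print(cache_dict["0_keys"].shape)  # (1, 8, 100, 64)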