diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
index fe088118..9829efb4 100644
--- a/llms/mlx_lm/cache_prompt.py
+++ b/llms/mlx_lm/cache_prompt.py
@@ -139,8 +139,8 @@ def main():
     print("Saving...")
     cache_dict = {}
    for i, c in enumerate(cache):
-        cache_dict[f"{i}_keys"] = c.state[0]
-        cache_dict[f"{i}_values"] = c.state[1]
+        cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
+        cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
     metadata = {}
     metadata["model"] = args.model
     metadata["chat_template"] = tokenizer.chat_template
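
For context, a minimal sketch of what the new slice does, assuming buffer shapes of the form (batch, heads, sequence, head_dim) and a hypothetical pre-allocation step size: the KV cache grows its key/value buffers in fixed-size steps along the sequence axis, so after processing a prompt the arrays in c.state can be longer than the c.offset tokens actually written, and slicing with [..., : c.offset, :] keeps only the valid prefix before saving.

import mlx.core as mx

# Hypothetical shapes for illustration; the real values depend on the model.
B, n_heads, step, head_dim = 1, 8, 256, 64
offset = 100  # number of tokens actually written into the cache

keys = mx.zeros((B, n_heads, step, head_dim))  # pre-allocated buffer
trimmed = keys[..., :offset, :]                # keep only the valid prefix

print(keys.shape)     # (1, 8, 256, 64) -- includes unused trailing padding
print(trimmed.shape)  # (1, 8, 100, 64) -- only this much is worth saving

Without the slice, the saved file would include the uninitialized padding past the offset, wasting disk space and storing garbage values.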