More cache improvements (#1015)

* fix rotating kv cache for chat use case

* reorganize and fix caching; unify prompt caching across cache types and use cases, e.g. caching during a chat

* nit in chat

* fix tests

* fix tests

* fix tests

* docs

* chat command

* comments + docs

* Define meta_state on all Cache implementations

* fixes + trim_prompt_cache api

* fix default model

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Author: Awni Hannun
Date: 2024-10-07 20:45:51 -07:00
Committed by: GitHub
Parent: 9bc53fc210
Commit: fca087be49
43 changed files with 1151 additions and 691 deletions
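
Taken together, the changes replace ad-hoc KV-cache handling with one prompt-cache API. Below is a minimal sketch of the workflow the commit message describes; only make_prompt_cache and save_prompt_cache are confirmed by this diff, while load_prompt_cache, the prompt_cache keyword on generate, the trim_prompt_cache signature, and the model name are assumptions for illustration:

    from mlx_lm import load, generate
    from mlx_lm.models.cache import (
        make_prompt_cache,
        save_prompt_cache,
        load_prompt_cache,  # assumed counterpart to save_prompt_cache
        trim_prompt_cache,  # API named in the commit message; signature assumed
    )

    # Placeholder model name for illustration.
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

    # One helper now builds the appropriate cache for the model,
    # bounded (rotating) when max_kv_size is set.
    cache = make_prompt_cache(model, max_kv_size=512)

    # Reusing the same cache object across turns is the chat use case
    # this PR unifies.
    for turn in ["Hello!", "Tell me more."]:
        print(generate(model, tokenizer, prompt=turn, prompt_cache=cache))

    # Persist the cache and restore it in a later run.
    save_prompt_cache("prompt_cache.safetensors", cache)
    cache = load_prompt_cache("prompt_cache.safetensors")

    # Drop trailing tokens from the cache, e.g. to rewind the last turn.
    trim_prompt_cache(cache, num_tokens=8)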


@@ -7,13 +7,14 @@ import time
 
 import mlx.core as mx
 
-from .utils import load, make_kv_caches
+from .models.cache import make_prompt_cache, save_prompt_cache
+from .utils import load
 
 
 def setup_arg_parser():
     """Set up and return the argument parser."""
     parser = argparse.ArgumentParser(
-        description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
+        description="Cache the state of a prompt to be reused with mlx_lm.generate"
     )
     parser.add_argument(
         "--model",
@@ -60,7 +61,9 @@ def setup_arg_parser():
         help="Set the maximum key-value cache size",
     )
     parser.add_argument(
-        "--kv-cache-file", help="The file to save the KV caches in", required=True
+        "--prompt-cache-file",
+        help="The file to save the prompt cache in",
+        required=True,
     )
     parser.add_argument(
         "--prompt",
@@ -115,7 +118,7 @@ def main():
     else:
         prompt = args.prompt
 
-    cache = make_kv_caches(model, args.max_kv_size)
+    cache = make_prompt_cache(model, args.max_kv_size)
 
     y = mx.array(tokenizer.encode(prompt))
     # Process the prompt
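
The call site keeps the same shape; only the helper changes. A rough sketch of what make_prompt_cache plausibly does, based on the rotating-cache fix in the commit message; the class names and the keep parameter are assumptions, not code from this diff:

    from mlx_lm.models.cache import KVCache, RotatingKVCache

    def make_prompt_cache_sketch(model, max_kv_size=None):
        # One cache object per transformer layer: bounded (rotating)
        # when a maximum KV size is requested, unbounded otherwise.
        if max_kv_size is not None:
            return [RotatingKVCache(max_size=max_kv_size, keep=4) for _ in model.layers]
        return [KVCache() for _ in model.layers]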
@@ -137,16 +140,12 @@ def main():
+    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
     print("Saving...")
-    cache_dict = {}
-    for i, c in enumerate(cache):
-        cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
-        cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
     metadata = {}
     metadata["model"] = args.model
     metadata["chat_template"] = tokenizer.chat_template
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
-    metadata["max_kv_size"] = str(args.max_kv_size)
-    mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
-    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
+    save_prompt_cache(args.prompt_cache_file, cache, metadata)
 
 
 if __name__ == "__main__":
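
The hand-rolled cache_dict serialization is replaced by save_prompt_cache, which can persist any cache type because every implementation now defines meta_state (per the commit message). A plausible sketch of the approach under those assumptions, not the repository's actual code:

    import json
    import mlx.core as mx

    def save_prompt_cache_sketch(file_name, cache, metadata=None):
        # Arrays from each layer's cache go into the tensor payload;
        # everything else rides along as string metadata.
        cache_data = {}
        cache_meta = dict(metadata or {})
        for i, c in enumerate(cache):
            for j, a in enumerate(c.state):
                cache_data[f"{i}_{j}"] = a
            cache_meta[f"{i}_class"] = type(c).__name__  # to rebuild the right cache type
            cache_meta[f"{i}_meta"] = json.dumps(c.meta_state)
        mx.save_safetensors(file_name, cache_data, metadata=cache_meta)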