Add the ability to load the KV cache from a file (#956)

commit 1003a8b2dd (parent 7f8c961287)
Author: Angelos Katharopoulos
Date: 2024-08-28 22:11:45 -07:00
Committed by: GitHub
5 changed files with 250 additions and 22 deletions
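
This change lets `generate.py` restore a previously computed KV cache from a safetensors file: per-layer key/value arrays are reloaded, and the file's metadata supplies fallbacks for the model path, tokenizer config, chat template, and maximum KV cache size. Inferred from the loader added below, a writer for that layout might look like the following sketch; the `save_kv_cache` helper is hypothetical and only mirrors what the loader reads back.

```python
# Hypothetical writer for the cache-file layout that load_kv_cache_from_file
# (below) consumes. The array key names ("<layer>_keys"/"<layer>_values") and
# metadata fields are inferred from this diff; everything else is illustrative.
import json

import mlx.core as mx


def save_kv_cache(path, cache_history, model_path, tokenizer, max_kv_size):
    arrays = {}
    for layer, (keys, values) in enumerate(cache_history):
        arrays[f"{layer}_keys"] = keys
        arrays[f"{layer}_values"] = values
    # Safetensors metadata values must be strings.
    metadata = {
        "model": model_path,
        "chat_template": tokenizer.chat_template,
        "tokenizer_config": json.dumps({"eos_token": tokenizer.eos_token}),
        "max_kv_size": str(max_kv_size),
    }
    mx.save_safetensors(path, arrays, metadata=metadata)
```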


@@ -1,17 +1,18 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import mlx.core as mx
from .utils import generate, load
DEFAULT_MODEL_PATH = "mlx_model"
DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.6
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MAX_KV_SIZE = 1024
def setup_arg_parser():
@@ -20,7 +21,6 @@ def setup_arg_parser():
parser.add_argument(
"--model",
type=str,
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
@@ -80,9 +80,14 @@ def setup_arg_parser():
parser.add_argument(
"--max-kv-size",
type=int,
default=1024,
help="Set the maximum key-value cache size",
)
parser.add_argument(
"--kv-cache-file",
type=str,
default=None,
help="A file containing saved KV caches to avoid recomputing them",
)
return parser
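
With the new flag, a cached prompt can be reused at generation time. Since the hard-coded `--model` default is dropped, the model path may come from the cache file's metadata instead. A hypothetical invocation (the cache file name is illustrative and must use the layout sketched above):

```bash
# --model can be omitted: with a cache file it falls back to metadata["model"].
python -m mlx_lm.generate \
    --kv-cache-file prompt_cache.safetensors \
    --prompt "And how do I install it?" \
    --max-tokens 100
```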
@@ -113,6 +118,24 @@ def colorprint_by_t0(s, t0):
colorprint(color, s)
def load_kv_cache_from_file(kv_cache_file):
if kv_cache_file is None:
return None, None
kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
cache_per_layer = {}
for k, x in kv_cache.items():
layer, kv_type = k.split("_")
if layer not in cache_per_layer:
cache_per_layer[layer] = {}
cache_per_layer[layer][kv_type] = x
cache_history = [None] * len(cache_per_layer)
for layer, c in cache_per_layer.items():
cache_history[int(layer)] = (c["keys"], c["values"])
return cache_history, metadata
def main():
parser = setup_arg_parser()
args = parser.parse_args()
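
The loader groups the flat safetensors keys by layer index and returns an ordered list of `(keys, values)` tuples plus the file's metadata. A quick usage sketch, assuming a cache file written in the layout above:

```python
# Usage sketch for the helper above; the file name is illustrative.
cache_history, metadata = load_kv_cache_from_file("prompt_cache.safetensors")
print(f"{len(cache_history)} layers cached for {metadata['model']}")
```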
@@ -122,13 +145,25 @@ def main():
if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Load the kv cache and metadata if a kv cache file is provided
cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
tokenizer_config = (
{} if cache_history is None else json.loads(metadata["tokenizer_config"])
)
if args.trust_remote_code:
tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
# If no model path is provided then use the one in the kv cache history
model_path = args.model
if cache_history is not None and model_path is None:
model_path = metadata["model"]
model, tokenizer = load(
args.model,
model_path,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
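
In `main()`, values recovered from the cache file act as defaults that explicit CLI flags override: the tokenizer config is seeded from `metadata["tokenizer_config"]` and then patched by `--trust-remote-code` and `--eos-token`, and the model path falls back to `metadata["model"]` only when `--model` is not given. A compact, self-contained illustration of that precedence (all values are made up):

```python
import json

# Stand-ins for cache-file metadata and a CLI flag (all values made up).
metadata = {"tokenizer_config": json.dumps({"eos_token": "</s>"})}
args_eos_token = "<|eot_id|>"  # as if --eos-token were passed

# Metadata seeds the config; an explicit flag then overrides it.
tokenizer_config = json.loads(metadata["tokenizer_config"])
if args_eos_token is not None:
    tokenizer_config["eos_token"] = args_eos_token

assert tokenizer_config == {"eos_token": "<|eot_id|>"}
```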
@@ -136,6 +171,8 @@ def main():
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
elif tokenizer.chat_template is None:
tokenizer.chat_template = metadata["chat_template"]
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
@@ -145,11 +182,30 @@ def main():
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Treat the prompt as a suffix assuming that the prefix is in the
# stored kv cache.
if cache_history is not None:
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
add_generation_prompt=True,
)
prompt = prompt[test_prompt.index("<query>") :]
else:
prompt = args.prompt
formatter = colorprint_by_t0 if args.colorize else None
# Determine the max kv size from the kv cache or passed arguments
max_kv_size = args.max_kv_size
if max_kv_size is None:
max_kv_size = (
int(metadata["max_kv_size"])
if cache_history is not None
else DEFAULT_MAX_KV_SIZE
)
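
When a cache is present, the freshly rendered prompt is treated as a suffix: the chat template is applied to a `<query>` sentinel, and everything before the sentinel's position is assumed to already be in the stored cache, so only the remainder is fed to the model. A standalone illustration with a made-up template:

```python
# Standalone illustration of the sentinel trick above. The template strings
# are made up; a real tokenizer.apply_chat_template plays this role.
TEMPLATE_PREFIX = "<|user|>\n"
TEMPLATE_SUFFIX = "\n<|assistant|>\n"


def apply_chat_template(content):
    return f"{TEMPLATE_PREFIX}{content}{TEMPLATE_SUFFIX}"


prompt = apply_chat_template("And the second question?")
test_prompt = apply_chat_template("<query>")

# Everything before the sentinel is template boilerplate whose tokens are
# already in the stored KV cache; keep only the new material.
prompt = prompt[test_prompt.index("<query>"):]
assert prompt == "And the second question?" + TEMPLATE_SUFFIX
```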
generate(
model,
tokenizer,
@@ -159,7 +215,8 @@ def main():
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
max_kv_size=args.max_kv_size,
max_kv_size=max_kv_size,
cache_history=cache_history,
)
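
Putting the pieces together, a minimal end-to-end sketch under the assumptions above (the cache file and prompt are illustrative; `generate` accepting a `cache_history` keyword matches the call in this diff):

```python
from mlx_lm.utils import generate, load

# load_kv_cache_from_file is the helper added in this diff; the cache file
# is assumed to have been written earlier with the sketched layout.
cache_history, metadata = load_kv_cache_from_file("prompt_cache.safetensors")
model, tokenizer = load(metadata["model"])
response = generate(
    model,
    tokenizer,
    prompt="And what are the next steps?",
    max_tokens=100,
    max_kv_size=int(metadata["max_kv_size"]),
    cache_history=cache_history,
)
```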