Load kv cache from a file

This commit is contained in:
Angelos Katharopoulos
2024-08-23 18:33:26 -07:00
parent 6731254e76
commit 920efec17e
5 changed files with 216 additions and 16 deletions

143
llms/mlx_lm/cache_prompt.py Normal file
View File

@@ -0,0 +1,143 @@
# Copyright © 2024 Apple Inc.
import argparse
import sys
import time
import mlx.core as mx
from .utils import load, make_kv_caches
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(
description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
)
parser.add_argument(
"--model",
type=str,
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Enable trusting remote code for tokenizer",
)
parser.add_argument(
"--eos-token",
type=str,
default=None,
help="End of sequence token for tokenizer",
)
parser.add_argument(
"--ignore-chat-template",
action="store_true",
help="Use the raw prompt without the tokenizer's chat template.",
)
parser.add_argument(
"--use-default-chat-template",
action="store_true",
help="Use the default chat template",
)
parser.add_argument(
"--cache-limit-gb",
type=int,
default=None,
help="Set the MLX cache limit in GB",
)
parser.add_argument(
"--max-kv-size",
type=int,
default=1024,
help="Set the maximum key-value cache size",
)
parser.add_argument(
"--kv-cache-file", help="The file to save the KV caches in", required=True
)
parser.add_argument(
"--prompt",
required=True,
help="Message to be processed by the model ('-' reads from stdin)",
)
return parser
def main():
parser = setup_arg_parser()
args = parser.parse_args()
if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
model, tokenizer = load(
args.model,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Treat the prompt as a prefix assuming that the suffix will be
# provided at generation time.
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
add_generation_prompt=True,
)
n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
prompt = prompt[:-n]
else:
prompt = args.prompt
cache = make_kv_caches(model, args.max_kv_size)
y = mx.array(tokenizer.encode(prompt))
# Process the prompt
processed = 0
step_size = 512
start = time.time()
max_msg_len = 0
while y.size > 0:
model(y[:step_size][None], cache=cache)
mx.eval([c.state for c in cache])
processed += min(y.size, step_size)
y = y[step_size:]
current = time.time()
speed = processed / (current - start)
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
max_msg_len = max(max_msg_len, len(msg))
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
print()
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
print("Saving...")
cache_dict = {}
for i, c in enumerate(cache):
cache_dict[f"{i}_keys"] = c.state[0]
cache_dict[f"{i}_values"] = c.state[1]
mx.save_safetensors(args.kv_cache_file, cache_dict)

View File

@@ -83,6 +83,12 @@ def setup_arg_parser():
default=1024,
help="Set the maximum key-value cache size",
)
parser.add_argument(
"--kv-cache-file",
type=str,
default=None,
help="A file containing saved KV caches to avoid recomputing them",
)
return parser
@@ -113,6 +119,24 @@ def colorprint_by_t0(s, t0):
colorprint(color, s)
def load_kv_cache_from_file(kv_cache_file):
if kv_cache_file is None:
return None
kv_cache = mx.load(kv_cache_file)
cache_per_layer = {}
for k, x in kv_cache.items():
layer, kv_type = k.split("_")
if layer not in cache_per_layer:
cache_per_layer[layer] = {}
cache_per_layer[layer][kv_type] = x
cache_history = [None] * len(cache_per_layer)
for layer, c in cache_per_layer.items():
cache_history[int(layer)] = (c["keys"], c["values"])
return cache_history
def main():
parser = setup_arg_parser()
args = parser.parse_args()
@@ -145,6 +169,16 @@ def main():
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Treat the prompt as a suffix assuming that the prefix is in the
# stored kv cache.
if args.kv_cache_file is not None:
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
add_generation_prompt=True,
)
prompt = prompt[test_prompt.index("<query>") :]
else:
prompt = args.prompt
@@ -160,6 +194,7 @@ def main():
temp=args.temp,
top_p=args.top_p,
max_kv_size=args.max_kv_size,
cache_history=load_kv_cache_from_file(args.kv_cache_file),
)

View File

@@ -46,6 +46,7 @@ class KVCache:
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
return self.keys, self.values
@@ -137,6 +138,7 @@ class RotatingKVCache:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
@property
def state(self):
return self.keys, self.values

View File

@@ -9,7 +9,7 @@ import shutil
import time
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
import mlx.core as mx
import mlx.nn as nn
@@ -126,6 +126,26 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
return logits
def make_kv_caches(
model: nn.Module, max_kv_size: Optional[int] = None
) -> List[Union[KVCache, RotatingKVCache]]:
if hasattr(model, "make_cache"):
return model.make_cache()
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
if max_kv_size is not None:
return [
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
for n in kv_heads
]
else:
return [KVCache(model.head_dim, n) for n in kv_heads]
def generate_step(
prompt: mx.array,
model: nn.Module,
@@ -138,6 +158,7 @@ def generate_step(
logit_bias: Optional[Dict[int, float]] = None,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -194,21 +215,19 @@ def generate_step(
)
y = prompt
if hasattr(model, "make_cache"):
cache = model.make_cache()
else:
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
if max_kv_size is not None:
cache = [
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
for n in kv_heads
]
else:
cache = [KVCache(model.head_dim, n) for n in kv_heads]
# Create the KV cache for generation
cache = make_kv_caches(model, max_kv_size)
if cache_history is not None:
if len(cache_history) != len(cache):
raise ValueError("Wrong number of layers in the cache history")
# Set the history in the cache objects and evaluate them to prepare for
# generation.
for c, h in zip(cache, cache_history):
c.update_and_fetch(h[0], h[1])
mx.eval([c.state for c in cache])
repetition_context = prompt.tolist()

View File

@@ -31,6 +31,7 @@ setup(
},
entry_points={
"console_scripts": [
"mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
"mlx_lm.convert = mlx_lm.convert:main",
"mlx_lm.fuse = mlx_lm.fuse:main",
"mlx_lm.generate = mlx_lm.generate:main",