Add the ability to load the KV cache from a file (#956)

Angelos Katharopoulos 2024-08-28 22:11:45 -07:00 committed by GitHub
parent 7f8c961287
commit 1003a8b2dd
5 changed files with 250 additions and 22 deletions
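In short, the new mlx_lm.cache_prompt script pre-processes a prompt once, writes each layer's keys and values to a safetensors file together with string metadata (model path, chat template, tokenizer config, max KV size), and mlx_lm.generate can then reload that file via --kv-cache-file instead of recomputing the prompt. A minimal sketch of the file layout the two scripts agree on (the file name and array shapes below are illustrative):

import mlx.core as mx

# Per-layer tensors are stored flat as "{layer}_keys" / "{layer}_values";
# safetensors metadata values must be strings.
cache_dict = {
    "0_keys": mx.zeros((1, 8, 16, 64)),
    "0_values": mx.zeros((1, 8, 16, 64)),
}
metadata = {"model": "mlx_model", "max_kv_size": "1024"}
mx.save_safetensors("prompt_cache.safetensors", cache_dict, metadata)

# Reading it back returns both the arrays and the metadata.
arrays, meta = mx.load("prompt_cache.safetensors", return_metadata=True)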

llms/mlx_lm/cache_prompt.py Normal file

@@ -0,0 +1,149 @@
# Copyright © 2024 Apple Inc.
import argparse
import json
import sys
import time
import mlx.core as mx
from .utils import load, make_kv_caches
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(
description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
)
parser.add_argument(
"--model",
type=str,
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Enable trusting remote code for tokenizer",
)
parser.add_argument(
"--eos-token",
type=str,
default=None,
help="End of sequence token for tokenizer",
)
parser.add_argument(
"--ignore-chat-template",
action="store_true",
help="Use the raw prompt without the tokenizer's chat template.",
)
parser.add_argument(
"--use-default-chat-template",
action="store_true",
help="Use the default chat template",
)
parser.add_argument(
"--cache-limit-gb",
type=int,
default=None,
help="Set the MLX cache limit in GB",
)
parser.add_argument(
"--max-kv-size",
type=int,
default=1024,
help="Set the maximum key-value cache size",
)
parser.add_argument(
"--kv-cache-file", help="The file to save the KV caches in", required=True
)
parser.add_argument(
"--prompt",
required=True,
help="Message to be processed by the model ('-' reads from stdin)",
)
return parser
def main():
parser = setup_arg_parser()
args = parser.parse_args()
if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
model, tokenizer = load(
args.model,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Treat the prompt as a prefix assuming that the suffix will be
# provided at generation time.
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
add_generation_prompt=True,
)
n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
prompt = prompt[:-n]
else:
prompt = args.prompt
cache = make_kv_caches(model, args.max_kv_size)
y = mx.array(tokenizer.encode(prompt))
# Process the prompt
processed = 0
step_size = 512
start = time.time()
max_msg_len = 0
while y.size > 0:
model(y[:step_size][None], cache=cache)
mx.eval([c.state for c in cache])
processed += min(y.size, step_size)
y = y[step_size:]
current = time.time()
speed = processed / (current - start)
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
max_msg_len = max(max_msg_len, len(msg))
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
print()
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
print("Saving...")
cache_dict = {}
for i, c in enumerate(cache):
cache_dict[f"{i}_keys"] = c.state[0]
cache_dict[f"{i}_values"] = c.state[1]
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
metadata["max_kv_size"] = str(args.max_kv_size)
mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
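The slicing at the end of the chat-template branch is the subtle part: the template appends a generation prompt (the assistant header) after the user content, so the script renders a sentinel "<query>" message to measure how many trailing characters the template adds, then trims them so the cached text is a pure prefix. A self-contained sketch with a made-up template (only the arithmetic mirrors the script; the template string is hypothetical):

def fake_chat_template(content):
    # Hypothetical single-turn template with a generation prompt appended
    # after the user content.
    return f"<|user|>\n{content}\n<|assistant|>\n"

prompt = fake_chat_template("A long document to be cached")
test_prompt = fake_chat_template("<query>")

# Number of characters the template adds *after* the user content.
n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
prefix = prompt[:-n]   # what cache_prompt caches
suffix = prompt[-n:]   # re-rendered by mlx_lm.generate at generation time
assert prefix + suffix == prompt

mlx_lm.generate applies the mirror-image slice (keeping everything from its own user content onward), so the cached prefix plus the freshly rendered query reconstruct one correctly templated conversation.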

llms/mlx_lm/generate.py

@@ -1,17 +1,18 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import mlx.core as mx
from .utils import generate, load
DEFAULT_MODEL_PATH = "mlx_model"
DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.6
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MAX_KV_SIZE = 1024
def setup_arg_parser():
@@ -20,7 +21,6 @@ def setup_arg_parser():
parser.add_argument(
"--model",
type=str,
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
@@ -80,9 +80,14 @@ def setup_arg_parser():
parser.add_argument(
"--max-kv-size",
type=int,
default=1024,
help="Set the maximum key-value cache size",
)
parser.add_argument(
"--kv-cache-file",
type=str,
default=None,
help="A file containing saved KV caches to avoid recomputing them",
)
return parser
@@ -113,6 +118,24 @@ def colorprint_by_t0(s, t0):
colorprint(color, s)
def load_kv_cache_from_file(kv_cache_file):
if kv_cache_file is None:
return None, None
kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
cache_per_layer = {}
for k, x in kv_cache.items():
layer, kv_type = k.split("_")
if layer not in cache_per_layer:
cache_per_layer[layer] = {}
cache_per_layer[layer][kv_type] = x
cache_history = [None] * len(cache_per_layer)
for layer, c in cache_per_layer.items():
cache_history[int(layer)] = (c["keys"], c["values"])
return cache_history, metadata
def main():
parser = setup_arg_parser()
args = parser.parse_args()
@@ -122,13 +145,25 @@ def main():
if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Load the kv cache and metadata if a kv cache file is provided
cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
tokenizer_config = (
{} if cache_history is None else json.loads(metadata["tokenizer_config"])
)
if args.trust_remote_code:
tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
# If no model path is provided then use the one in the kv cache history
model_path = args.model
if cache_history is not None and model_path is None:
model_path = metadata["model"]
model, tokenizer = load(
args.model,
model_path,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
@@ -136,6 +171,8 @@ def main():
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
elif tokenizer.chat_template is None:
tokenizer.chat_template = metadata["chat_template"]
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
@@ -145,11 +182,30 @@ def main():
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Treat the prompt as a suffix assuming that the prefix is in the
# stored kv cache.
if cache_history is not None:
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
add_generation_prompt=True,
)
prompt = prompt[test_prompt.index("<query>") :]
else:
prompt = args.prompt
formatter = colorprint_by_t0 if args.colorize else None
# Determine the max kv size from the kv cache or passed arguments
max_kv_size = args.max_kv_size
if max_kv_size is None:
max_kv_size = (
int(metadata["max_kv_size"])
if cache_history is not None
else DEFAULT_MAX_KV_SIZE
)
generate(
model,
tokenizer,
@@ -159,7 +215,8 @@ def main():
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
max_kv_size=args.max_kv_size,
max_kv_size=max_kv_size,
cache_history=cache_history,
)
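The same pieces compose from Python without the CLI. A sketch of reusing a saved cache programmatically (the file name and prompt are illustrative, and it relies on generate accepting the max_kv_size and cache_history keywords, as the updated call above does):

import json
import mlx.core as mx
from mlx_lm.utils import generate, load

kv_cache, metadata = mx.load("prompt_cache.safetensors", return_metadata=True)
cache_history = [
    (kv_cache[f"{i}_keys"], kv_cache[f"{i}_values"])
    for i in range(len(kv_cache) // 2)
]

model, tokenizer = load(
    metadata["model"],
    tokenizer_config=json.loads(metadata["tokenizer_config"]),
)

# The prompt here is only the new suffix; chat templating is skipped for brevity.
response = generate(
    model,
    tokenizer,
    "What does the cached document conclude?",
    max_tokens=100,
    max_kv_size=int(metadata["max_kv_size"]),
    cache_history=cache_history,
)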

llms/mlx_lm/models/base.py

@@ -46,6 +46,7 @@ class KVCache:
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
return self.keys, self.values
@@ -137,6 +138,7 @@ class RotatingKVCache:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
@property
def state(self):
return self.keys, self.values
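The only change to the cache classes is decorating the existing state method as a property, so callers read c.state instead of calling c.state(); that is what cache_prompt.py relies on when saving and what generate_step relies on when forcing evaluation. A small illustration (the dimensions are arbitrary):

import mlx.core as mx
from mlx_lm.models.base import KVCache

c = KVCache(64, 8)                       # head_dim=64, n_kv_heads=8
keys = mx.zeros((1, 8, 16, 64))
values = mx.zeros((1, 8, 16, 64))
c.update_and_fetch(keys, values)

mx.eval(c.state)                         # materialize the lazy arrays
layer_keys, layer_values = c.state       # the tensors cache_prompt.py saves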

llms/mlx_lm/utils.py

@@ -9,7 +9,7 @@ import shutil
import time
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
import mlx.core as mx
import mlx.nn as nn
@@ -126,6 +126,26 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
return logits
def make_kv_caches(
model: nn.Module, max_kv_size: Optional[int] = None
) -> List[Union[KVCache, RotatingKVCache]]:
if hasattr(model, "make_cache"):
return model.make_cache()
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
if max_kv_size is not None:
return [
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
for n in kv_heads
]
else:
return [KVCache(model.head_dim, n) for n in kv_heads]
def generate_step(
prompt: mx.array,
model: nn.Module,
@@ -138,6 +158,7 @@ def generate_step(
logit_bias: Optional[Dict[int, float]] = None,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -194,21 +215,19 @@
)
y = prompt
if hasattr(model, "make_cache"):
cache = model.make_cache()
else:
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
if max_kv_size is not None:
cache = [
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
for n in kv_heads
]
else:
cache = [KVCache(model.head_dim, n) for n in kv_heads]
# Create the KV cache for generation
cache = make_kv_caches(model, max_kv_size)
if cache_history is not None:
if len(cache_history) != len(cache):
raise ValueError("Wrong number of layers in the cache history")
# Set the history in the cache objects and evaluate them to prepare for
# generation.
for c, h in zip(cache, cache_history):
c.update_and_fetch(h[0], h[1])
mx.eval([c.state for c in cache])
repetition_context = prompt.tolist()
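The cache_history path can also be exercised by hand when driving the model directly, which makes the seeding semantics clear. A sketch (the model path and prompt are illustrative) that prefills once, captures each layer's state, and then rebuilds equivalent caches the way generate_step now does:

import mlx.core as mx
from mlx_lm.utils import load, make_kv_caches

model, tokenizer = load("mlx_model")
tokens = mx.array(tokenizer.encode("A long shared context."))[None]

# Prefill once and capture each layer's (keys, values).
prefill_cache = make_kv_caches(model)
model(tokens, cache=prefill_cache)
history = [c.state for c in prefill_cache]

# Later (e.g. after a save/load round trip): seed fresh caches from the history.
cache = make_kv_caches(model)
if len(history) != len(cache):
    raise ValueError("Wrong number of layers in the cache history")
for c, (keys, values) in zip(cache, history):
    c.update_and_fetch(keys, values)
mx.eval([c.state for c in cache])
# Each cache now continues from the cached prompt, so the next model(...) call
# only has to process newly appended tokens.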

llms/setup.py

@@ -31,6 +31,7 @@ setup(
},
entry_points={
"console_scripts": [
"mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
"mlx_lm.convert = mlx_lm.convert:main",
"mlx_lm.fuse = mlx_lm.fuse:main",
"mlx_lm.generate = mlx_lm.generate:main",