Add the ability to load the KV cache from a file (#956)

Angelos Katharopoulos authored on 2024-08-28 22:11:45 -07:00, committed by GitHub
parent 7f8c961287
commit 1003a8b2dd
5 changed files with 250 additions and 22 deletions

llms/mlx_lm/cache_prompt.py (new file, 149 lines)

@@ -0,0 +1,149 @@
# Copyright © 2024 Apple Inc.

import argparse
import json
import sys
import time

import mlx.core as mx

from .utils import load, make_kv_caches


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(
        description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Enable trusting remote code for tokenizer",
    )
    parser.add_argument(
        "--eos-token",
        type=str,
        default=None,
        help="End of sequence token for tokenizer",
    )
    parser.add_argument(
        "--ignore-chat-template",
        action="store_true",
        help="Use the raw prompt without the tokenizer's chat template.",
    )
    parser.add_argument(
        "--use-default-chat-template",
        action="store_true",
        help="Use the default chat template",
    )
    parser.add_argument(
        "--cache-limit-gb",
        type=int,
        default=None,
        help="Set the MLX cache limit in GB",
    )
    parser.add_argument(
        "--max-kv-size",
        type=int,
        default=1024,
        help="Set the maximum key-value cache size",
    )
    parser.add_argument(
        "--kv-cache-file", help="The file to save the KV caches in", required=True
    )
    parser.add_argument(
        "--prompt",
        required=True,
        help="Message to be processed by the model ('-' reads from stdin)",
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()

    if args.cache_limit_gb is not None:
        mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)

    # Building tokenizer_config
    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
    if args.eos_token is not None:
        tokenizer_config["eos_token"] = args.eos_token

    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config=tokenizer_config,
    )

    args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt

    if args.use_default_chat_template:
        if tokenizer.chat_template is None:
            tokenizer.chat_template = tokenizer.default_chat_template

    if not args.ignore_chat_template and (
        hasattr(tokenizer, "apply_chat_template")
        and tokenizer.chat_template is not None
    ):
        messages = [{"role": "user", "content": args.prompt}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Treat the prompt as a prefix assuming that the suffix will be
        # provided at generation time.
        test_prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": "<query>"}],
            tokenize=False,
            add_generation_prompt=True,
        )
        n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
        prompt = prompt[:-n]
    else:
        prompt = args.prompt

    cache = make_kv_caches(model, args.max_kv_size)

    y = mx.array(tokenizer.encode(prompt))

    # Process the prompt
    processed = 0
    step_size = 512
    start = time.time()
    max_msg_len = 0
    while y.size > 0:
        model(y[:step_size][None], cache=cache)
        mx.eval([c.state for c in cache])
        processed += min(y.size, step_size)
        y = y[step_size:]
        current = time.time()
        speed = processed / (current - start)
        msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
        max_msg_len = max(max_msg_len, len(msg))
        print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
    print()
    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")

    print("Saving...")
    cache_dict = {}
    for i, c in enumerate(cache):
        cache_dict[f"{i}_keys"] = c.state[0]
        cache_dict[f"{i}_values"] = c.state[1]
    metadata = {}
    metadata["model"] = args.model
    metadata["chat_template"] = tokenizer.chat_template
    metadata["tokenizer_config"] = json.dumps(tokenizer_config)
    metadata["max_kv_size"] = str(args.max_kv_size)
    mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
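
For reference, a minimal sketch (not part of this commit) of reading back the file that cache_prompt.py writes; the path prompt_cache.safetensors is only an assumed example:

# Sketch only: inspect a cache file produced by mlx_lm.cache_prompt.
# Per-layer arrays are stored as "<layer>_keys" / "<layer>_values", and the
# safetensors metadata records the model, chat template, tokenizer config
# and max_kv_size used when the cache was built.
import mlx.core as mx

kv_cache, metadata = mx.load("prompt_cache.safetensors", return_metadata=True)
num_layers = len(kv_cache) // 2
keys0, values0 = kv_cache["0_keys"], kv_cache["0_values"]
print(num_layers, keys0.shape, values0.shape)
print(metadata["model"], metadata["max_kv_size"])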

llms/mlx_lm/generate.py

@@ -1,17 +1,18 @@
 # Copyright © 2023-2024 Apple Inc.

 import argparse
+import json

 import mlx.core as mx

 from .utils import generate, load

-DEFAULT_MODEL_PATH = "mlx_model"
 DEFAULT_PROMPT = "hello"
 DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.6
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
+DEFAULT_MAX_KV_SIZE = 1024


 def setup_arg_parser():
@@ -20,7 +21,6 @@ def setup_arg_parser():
     parser.add_argument(
         "--model",
         type=str,
-        default="mlx_model",
         help="The path to the local model directory or Hugging Face repo.",
     )
     parser.add_argument(
@@ -80,9 +80,14 @@ def setup_arg_parser():
     parser.add_argument(
         "--max-kv-size",
         type=int,
-        default=1024,
         help="Set the maximum key-value cache size",
     )
+    parser.add_argument(
+        "--kv-cache-file",
+        type=str,
+        default=None,
+        help="A file containing saved KV caches to avoid recomputing them",
+    )

     return parser
@@ -113,6 +118,24 @@ def colorprint_by_t0(s, t0):
     colorprint(color, s)


+def load_kv_cache_from_file(kv_cache_file):
+    if kv_cache_file is None:
+        return None, None
+
+    kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
+    cache_per_layer = {}
+    for k, x in kv_cache.items():
+        layer, kv_type = k.split("_")
+        if layer not in cache_per_layer:
+            cache_per_layer[layer] = {}
+        cache_per_layer[layer][kv_type] = x
+
+    cache_history = [None] * len(cache_per_layer)
+    for layer, c in cache_per_layer.items():
+        cache_history[int(layer)] = (c["keys"], c["values"])
+    return cache_history, metadata
+
+
 def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
@@ -122,13 +145,25 @@ def main():
     if args.cache_limit_gb is not None:
         mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)

+    # Load the kv cache and metadata if a kv cache file is provided
+    cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
+
     # Building tokenizer_config
-    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
+    tokenizer_config = (
+        {} if cache_history is None else json.loads(metadata["tokenizer_config"])
+    )
+    if args.trust_remote_code:
+        tokenizer_config["trust_remote_code"] = True
     if args.eos_token is not None:
         tokenizer_config["eos_token"] = args.eos_token

+    # If no model path is provided then use the one in the kv cache history
+    model_path = args.model
+    if cache_history is not None and model_path is None:
+        model_path = metadata["model"]
+
     model, tokenizer = load(
-        args.model,
+        model_path,
         adapter_path=args.adapter_path,
         tokenizer_config=tokenizer_config,
     )
@@ -136,6 +171,8 @@ def main():
     if args.use_default_chat_template:
         if tokenizer.chat_template is None:
             tokenizer.chat_template = tokenizer.default_chat_template
+    elif tokenizer.chat_template is None:
+        tokenizer.chat_template = metadata["chat_template"]

     if not args.ignore_chat_template and (
         hasattr(tokenizer, "apply_chat_template")
@@ -145,11 +182,30 @@ def main():
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
+
+        # Treat the prompt as a suffix assuming that the prefix is in the
+        # stored kv cache.
+        if cache_history is not None:
+            test_prompt = tokenizer.apply_chat_template(
+                [{"role": "user", "content": "<query>"}],
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            prompt = prompt[test_prompt.index("<query>") :]
     else:
         prompt = args.prompt

     formatter = colorprint_by_t0 if args.colorize else None
+
+    # Determine the max kv size from the kv cache or passed arguments
+    max_kv_size = args.max_kv_size
+    if max_kv_size is None:
+        max_kv_size = (
+            int(metadata["max_kv_size"])
+            if cache_history is not None
+            else DEFAULT_MAX_KV_SIZE
+        )
+
     generate(
         model,
         tokenizer,
@@ -159,7 +215,8 @@ def main():
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
-        max_kv_size=args.max_kv_size,
+        max_kv_size=max_kv_size,
+        cache_history=cache_history,
     )
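
For context, a minimal sketch (not part of the diff) of reusing a saved cache programmatically rather than through the mlx_lm.generate CLI; the cache path and prompt text are assumed examples:

# Sketch only: assumes "prompt_cache.safetensors" was written by mlx_lm.cache_prompt,
# so the chat-template prefix is already baked into the cached keys/values and the
# prompt below is treated as the suffix that follows it.
from mlx_lm.generate import load_kv_cache_from_file
from mlx_lm.utils import generate, load

cache_history, metadata = load_kv_cache_from_file("prompt_cache.safetensors")

# Reload the same model and max_kv_size that were recorded when caching.
model, tokenizer = load(metadata["model"])
text = generate(
    model,
    tokenizer,
    prompt="And here is the question that follows the cached prefix.",
    max_kv_size=int(metadata["max_kv_size"]),
    cache_history=cache_history,  # keyword threaded through to generate_step in this commit
    verbose=True,
)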

llms/mlx_lm/models/base.py

@@ -46,6 +46,7 @@ class KVCache:
             self.values[..., prev : self.offset, :] = values

         return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

+    @property
     def state(self):
         return self.keys, self.values
@@ -137,6 +138,7 @@ class RotatingKVCache:
             return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
         return self.keys, self.values

+    @property
     def state(self):
         return self.keys, self.values
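
A small sketch (not part of the commit) of why state becomes a property: callers such as cache_prompt.py and generate_step() read it as an attribute to save and re-seed a cache. The import path, constructor arguments and shapes below are assumptions for illustration:

# Sketch only: round-trip one layer's keys/values through the `state` property.
import mlx.core as mx
from mlx_lm.models.base import KVCache  # assumed module path for KVCache

src = KVCache(64, 8)                    # head_dim=64, n_kv_heads=8 (illustrative)
keys = mx.zeros((1, 8, 16, 64))         # (batch, kv_heads, tokens, head_dim)
values = mx.zeros((1, 8, 16, 64))
src.update_and_fetch(keys, values)

saved_keys, saved_values = src.state    # attribute access, no call, thanks to @property

dst = KVCache(64, 8)
dst.update_and_fetch(saved_keys, saved_values)  # seed a fresh cache, as generate_step() does
mx.eval(list(dst.state))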

llms/mlx_lm/utils.py

@@ -9,7 +9,7 @@ import shutil
 import time
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -126,6 +126,26 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
     return logits


+def make_kv_caches(
+    model: nn.Module, max_kv_size: Optional[int] = None
+) -> List[Union[KVCache, RotatingKVCache]]:
+    if hasattr(model, "make_cache"):
+        return model.make_cache()
+
+    kv_heads = (
+        [model.n_kv_heads] * len(model.layers)
+        if isinstance(model.n_kv_heads, int)
+        else model.n_kv_heads
+    )
+    if max_kv_size is not None:
+        return [
+            RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
+            for n in kv_heads
+        ]
+    else:
+        return [KVCache(model.head_dim, n) for n in kv_heads]
+
+
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
@@ -138,6 +158,7 @@ def generate_step(
     logit_bias: Optional[Dict[int, float]] = None,
     prefill_step_size: int = 512,
     max_kv_size: Optional[int] = None,
+    cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -194,21 +215,19 @@ def generate_step(
     )

     y = prompt

-    if hasattr(model, "make_cache"):
-        cache = model.make_cache()
-    else:
-        kv_heads = (
-            [model.n_kv_heads] * len(model.layers)
-            if isinstance(model.n_kv_heads, int)
-            else model.n_kv_heads
-        )
-        if max_kv_size is not None:
-            cache = [
-                RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
-                for n in kv_heads
-            ]
-        else:
-            cache = [KVCache(model.head_dim, n) for n in kv_heads]
+    # Create the KV cache for generation
+    cache = make_kv_caches(model, max_kv_size)
+
+    if cache_history is not None:
+        if len(cache_history) != len(cache):
+            raise ValueError("Wrong number of layers in the cache history")
+
+        # Set the history in the cache objects and evaluate them to prepare for
+        # generation.
+        for c, h in zip(cache, cache_history):
+            c.update_and_fetch(h[0], h[1])
+        mx.eval([c.state for c in cache])

     repetition_context = prompt.tolist()
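
For illustration, a minimal sketch (not part of the diff) of the make_kv_caches() helper factored out above; the model repo name is only an example:

# Sketch only: make_kv_caches() returns one cache object per layer, choosing
# RotatingKVCache when max_kv_size is given and a plain KVCache otherwise.
from mlx_lm.utils import load, make_kv_caches

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # example repo
unbounded = make_kv_caches(model)                 # KVCache per layer, grows with the prompt
bounded = make_kv_caches(model, max_kv_size=512)  # RotatingKVCache per layer, keep=4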

llms/setup.py

@@ -31,6 +31,7 @@ setup(
     },
     entry_points={
         "console_scripts": [
+            "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
             "mlx_lm.convert = mlx_lm.convert:main",
             "mlx_lm.fuse = mlx_lm.fuse:main",
             "mlx_lm.generate = mlx_lm.generate:main",