mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add the ability to load the KV cache from a file (#956)
This commit is contained in:
parent
7f8c961287
commit
1003a8b2dd
149
llms/mlx_lm/cache_prompt.py
Normal file
149
llms/mlx_lm/cache_prompt.py
Normal file
@ -0,0 +1,149 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .utils import load, make_kv_caches
|
||||
|
||||
|
||||
def setup_arg_parser():
|
||||
"""Set up and return the argument parser."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
type=str,
|
||||
help="Optional path for the trained adapter weights and config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
help="Enable trusting remote code for tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eos-token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="End of sequence token for tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore-chat-template",
|
||||
action="store_true",
|
||||
help="Use the raw prompt without the tokenizer's chat template.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-default-chat-template",
|
||||
action="store_true",
|
||||
help="Use the default chat template",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache-limit-gb",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the MLX cache limit in GB",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-kv-size",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Set the maximum key-value cache size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-cache-file", help="The file to save the KV caches in", required=True
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
required=True,
|
||||
help="Message to be processed by the model ('-' reads from stdin)",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cache_limit_gb is not None:
|
||||
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
|
||||
|
||||
# Building tokenizer_config
|
||||
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
|
||||
if args.eos_token is not None:
|
||||
tokenizer_config["eos_token"] = args.eos_token
|
||||
|
||||
model, tokenizer = load(
|
||||
args.model,
|
||||
adapter_path=args.adapter_path,
|
||||
tokenizer_config=tokenizer_config,
|
||||
)
|
||||
|
||||
args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
|
||||
|
||||
if args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
|
||||
if not args.ignore_chat_template and (
|
||||
hasattr(tokenizer, "apply_chat_template")
|
||||
and tokenizer.chat_template is not None
|
||||
):
|
||||
messages = [{"role": "user", "content": args.prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Treat the prompt as a prefix assuming that the suffix will be
|
||||
# provided at generation time.
|
||||
test_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": "<query>"}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
|
||||
prompt = prompt[:-n]
|
||||
else:
|
||||
prompt = args.prompt
|
||||
|
||||
cache = make_kv_caches(model, args.max_kv_size)
|
||||
y = mx.array(tokenizer.encode(prompt))
|
||||
|
||||
# Process the prompt
|
||||
processed = 0
|
||||
step_size = 512
|
||||
start = time.time()
|
||||
max_msg_len = 0
|
||||
while y.size > 0:
|
||||
model(y[:step_size][None], cache=cache)
|
||||
mx.eval([c.state for c in cache])
|
||||
processed += min(y.size, step_size)
|
||||
y = y[step_size:]
|
||||
current = time.time()
|
||||
speed = processed / (current - start)
|
||||
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
|
||||
max_msg_len = max(max_msg_len, len(msg))
|
||||
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
|
||||
print()
|
||||
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
|
||||
|
||||
print("Saving...")
|
||||
cache_dict = {}
|
||||
for i, c in enumerate(cache):
|
||||
cache_dict[f"{i}_keys"] = c.state[0]
|
||||
cache_dict[f"{i}_values"] = c.state[1]
|
||||
metadata = {}
|
||||
metadata["model"] = args.model
|
||||
metadata["chat_template"] = tokenizer.chat_template
|
||||
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
||||
metadata["max_kv_size"] = str(args.max_kv_size)
|
||||
mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
|
@ -1,17 +1,18 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .utils import generate, load
|
||||
|
||||
DEFAULT_MODEL_PATH = "mlx_model"
|
||||
DEFAULT_PROMPT = "hello"
|
||||
DEFAULT_MAX_TOKENS = 100
|
||||
DEFAULT_TEMP = 0.6
|
||||
DEFAULT_TOP_P = 1.0
|
||||
DEFAULT_SEED = 0
|
||||
DEFAULT_MAX_KV_SIZE = 1024
|
||||
|
||||
|
||||
def setup_arg_parser():
|
||||
@ -20,7 +21,6 @@ def setup_arg_parser():
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -80,9 +80,14 @@ def setup_arg_parser():
|
||||
parser.add_argument(
|
||||
"--max-kv-size",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Set the maximum key-value cache size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-cache-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A file containing saved KV caches to avoid recomputing them",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@ -113,6 +118,24 @@ def colorprint_by_t0(s, t0):
|
||||
colorprint(color, s)
|
||||
|
||||
|
||||
def load_kv_cache_from_file(kv_cache_file):
|
||||
if kv_cache_file is None:
|
||||
return None, None
|
||||
|
||||
kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
|
||||
cache_per_layer = {}
|
||||
for k, x in kv_cache.items():
|
||||
layer, kv_type = k.split("_")
|
||||
if layer not in cache_per_layer:
|
||||
cache_per_layer[layer] = {}
|
||||
cache_per_layer[layer][kv_type] = x
|
||||
|
||||
cache_history = [None] * len(cache_per_layer)
|
||||
for layer, c in cache_per_layer.items():
|
||||
cache_history[int(layer)] = (c["keys"], c["values"])
|
||||
return cache_history, metadata
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
@ -122,13 +145,25 @@ def main():
|
||||
if args.cache_limit_gb is not None:
|
||||
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
|
||||
|
||||
# Load the kv cache and metadata if a kv cache file is provided
|
||||
cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
|
||||
|
||||
# Building tokenizer_config
|
||||
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
|
||||
tokenizer_config = (
|
||||
{} if cache_history is None else json.loads(metadata["tokenizer_config"])
|
||||
)
|
||||
if args.trust_remote_code:
|
||||
tokenizer_config["trust_remote_code"] = True
|
||||
if args.eos_token is not None:
|
||||
tokenizer_config["eos_token"] = args.eos_token
|
||||
|
||||
# If no model path is provided then use the one in the kv cache history
|
||||
model_path = args.model
|
||||
if cache_history is not None and model_path is None:
|
||||
model_path = metadata["model"]
|
||||
|
||||
model, tokenizer = load(
|
||||
args.model,
|
||||
model_path,
|
||||
adapter_path=args.adapter_path,
|
||||
tokenizer_config=tokenizer_config,
|
||||
)
|
||||
@ -136,6 +171,8 @@ def main():
|
||||
if args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
elif tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = metadata["chat_template"]
|
||||
|
||||
if not args.ignore_chat_template and (
|
||||
hasattr(tokenizer, "apply_chat_template")
|
||||
@ -145,11 +182,30 @@ def main():
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Treat the prompt as a suffix assuming that the prefix is in the
|
||||
# stored kv cache.
|
||||
if cache_history is not None:
|
||||
test_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": "<query>"}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
prompt = prompt[test_prompt.index("<query>") :]
|
||||
else:
|
||||
prompt = args.prompt
|
||||
|
||||
formatter = colorprint_by_t0 if args.colorize else None
|
||||
|
||||
# Determine the max kv size from the kv cache or passed arguments
|
||||
max_kv_size = args.max_kv_size
|
||||
if max_kv_size is None:
|
||||
max_kv_size = (
|
||||
int(metadata["max_kv_size"])
|
||||
if cache_history is not None
|
||||
else DEFAULT_MAX_KV_SIZE
|
||||
)
|
||||
|
||||
generate(
|
||||
model,
|
||||
tokenizer,
|
||||
@ -159,7 +215,8 @@ def main():
|
||||
formatter=formatter,
|
||||
temp=args.temp,
|
||||
top_p=args.top_p,
|
||||
max_kv_size=args.max_kv_size,
|
||||
max_kv_size=max_kv_size,
|
||||
cache_history=cache_history,
|
||||
)
|
||||
|
||||
|
||||
|
@ -46,6 +46,7 @@ class KVCache:
|
||||
self.values[..., prev : self.offset, :] = values
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.keys, self.values
|
||||
|
||||
@ -137,6 +138,7 @@ class RotatingKVCache:
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
return self.keys, self.values
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.keys, self.values
|
||||
|
||||
|
@ -9,7 +9,7 @@ import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -126,6 +126,26 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
|
||||
return logits
|
||||
|
||||
|
||||
def make_kv_caches(
|
||||
model: nn.Module, max_kv_size: Optional[int] = None
|
||||
) -> List[Union[KVCache, RotatingKVCache]]:
|
||||
if hasattr(model, "make_cache"):
|
||||
return model.make_cache()
|
||||
|
||||
kv_heads = (
|
||||
[model.n_kv_heads] * len(model.layers)
|
||||
if isinstance(model.n_kv_heads, int)
|
||||
else model.n_kv_heads
|
||||
)
|
||||
if max_kv_size is not None:
|
||||
return [
|
||||
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
|
||||
for n in kv_heads
|
||||
]
|
||||
else:
|
||||
return [KVCache(model.head_dim, n) for n in kv_heads]
|
||||
|
||||
|
||||
def generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
@ -138,6 +158,7 @@ def generate_step(
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
prefill_step_size: int = 512,
|
||||
max_kv_size: Optional[int] = None,
|
||||
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing token ids based on the given prompt from the model.
|
||||
@ -194,21 +215,19 @@ def generate_step(
|
||||
)
|
||||
|
||||
y = prompt
|
||||
if hasattr(model, "make_cache"):
|
||||
cache = model.make_cache()
|
||||
else:
|
||||
kv_heads = (
|
||||
[model.n_kv_heads] * len(model.layers)
|
||||
if isinstance(model.n_kv_heads, int)
|
||||
else model.n_kv_heads
|
||||
)
|
||||
if max_kv_size is not None:
|
||||
cache = [
|
||||
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
|
||||
for n in kv_heads
|
||||
]
|
||||
else:
|
||||
cache = [KVCache(model.head_dim, n) for n in kv_heads]
|
||||
|
||||
# Create the KV cache for generation
|
||||
cache = make_kv_caches(model, max_kv_size)
|
||||
|
||||
if cache_history is not None:
|
||||
if len(cache_history) != len(cache):
|
||||
raise ValueError("Wrong number of layers in the cache history")
|
||||
|
||||
# Set the history in the cache objects and evaluate them to prepare for
|
||||
# generation.
|
||||
for c, h in zip(cache, cache_history):
|
||||
c.update_and_fetch(h[0], h[1])
|
||||
mx.eval([c.state for c in cache])
|
||||
|
||||
repetition_context = prompt.tolist()
|
||||
|
||||
|
@ -31,6 +31,7 @@ setup(
|
||||
},
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
|
||||
"mlx_lm.convert = mlx_lm.convert:main",
|
||||
"mlx_lm.fuse = mlx_lm.fuse:main",
|
||||
"mlx_lm.generate = mlx_lm.generate:main",
|
||||
|
Loading…
Reference in New Issue
Block a user