Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Add the ability to load the KV cache from a file (#956)

parent: 7f8c961287
commit: 1003a8b2dd
llms/mlx_lm/cache_prompt.py (new file, 149 lines)
@@ -0,0 +1,149 @@
# Copyright © 2024 Apple Inc.

import argparse
import json
import sys
import time

import mlx.core as mx

from .utils import load, make_kv_caches


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(
        description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Enable trusting remote code for tokenizer",
    )
    parser.add_argument(
        "--eos-token",
        type=str,
        default=None,
        help="End of sequence token for tokenizer",
    )
    parser.add_argument(
        "--ignore-chat-template",
        action="store_true",
        help="Use the raw prompt without the tokenizer's chat template.",
    )
    parser.add_argument(
        "--use-default-chat-template",
        action="store_true",
        help="Use the default chat template",
    )
    parser.add_argument(
        "--cache-limit-gb",
        type=int,
        default=None,
        help="Set the MLX cache limit in GB",
    )
    parser.add_argument(
        "--max-kv-size",
        type=int,
        default=1024,
        help="Set the maximum key-value cache size",
    )
    parser.add_argument(
        "--kv-cache-file", help="The file to save the KV caches in", required=True
    )
    parser.add_argument(
        "--prompt",
        required=True,
        help="Message to be processed by the model ('-' reads from stdin)",
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()

    if args.cache_limit_gb is not None:
        mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)

    # Building tokenizer_config
    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
    if args.eos_token is not None:
        tokenizer_config["eos_token"] = args.eos_token

    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config=tokenizer_config,
    )

    args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt

    if args.use_default_chat_template:
        if tokenizer.chat_template is None:
            tokenizer.chat_template = tokenizer.default_chat_template

    if not args.ignore_chat_template and (
        hasattr(tokenizer, "apply_chat_template")
        and tokenizer.chat_template is not None
    ):
        messages = [{"role": "user", "content": args.prompt}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Treat the prompt as a prefix assuming that the suffix will be
        # provided at generation time.
        test_prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": "<query>"}],
            tokenize=False,
            add_generation_prompt=True,
        )
        n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
        prompt = prompt[:-n]
    else:
        prompt = args.prompt

    cache = make_kv_caches(model, args.max_kv_size)
    y = mx.array(tokenizer.encode(prompt))

    # Process the prompt
    processed = 0
    step_size = 512
    start = time.time()
    max_msg_len = 0
    while y.size > 0:
        model(y[:step_size][None], cache=cache)
        mx.eval([c.state for c in cache])
        processed += min(y.size, step_size)
        y = y[step_size:]
        current = time.time()
        speed = processed / (current - start)
        msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
        max_msg_len = max(max_msg_len, len(msg))
        print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
    print()
    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")

    print("Saving...")
    cache_dict = {}
    for i, c in enumerate(cache):
        cache_dict[f"{i}_keys"] = c.state[0]
        cache_dict[f"{i}_values"] = c.state[1]
    metadata = {}
    metadata["model"] = args.model
    metadata["chat_template"] = tokenizer.chat_template
    metadata["tokenizer_config"] = json.dumps(tokenizer_config)
    metadata["max_kv_size"] = str(args.max_kv_size)
    mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
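The file written above is a plain safetensors archive: one "{i}_keys" / "{i}_values" array pair per layer, plus string metadata for model, chat_template, tokenizer_config, and max_kv_size. A minimal sketch (not part of the commit) of inspecting such a file, with "prompt_cache.safetensors" as a hypothetical path:

# Minimal sketch: inspect a cache file written by cache_prompt.py.
# "prompt_cache.safetensors" is a hypothetical path, not from the commit.
import mlx.core as mx

arrays, metadata = mx.load("prompt_cache.safetensors", return_metadata=True)

# Metadata values are stored as strings.
print("model:", metadata["model"])
print("max_kv_size:", metadata["max_kv_size"])

# Arrays are keyed per layer as "{i}_keys" and "{i}_values".
num_layers = len(arrays) // 2
for i in range(num_layers):
    keys, values = arrays[f"{i}_keys"], arrays[f"{i}_values"]
    print(f"layer {i}: keys {keys.shape}, values {values.shape}")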
llms/mlx_lm/generate.py

@@ -1,17 +1,18 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import argparse
+import json
 
 import mlx.core as mx
 
 from .utils import generate, load
 
-DEFAULT_MODEL_PATH = "mlx_model"
 DEFAULT_PROMPT = "hello"
 DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.6
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
+DEFAULT_MAX_KV_SIZE = 1024
 
 
 def setup_arg_parser():
@@ -20,7 +21,6 @@ def setup_arg_parser():
     parser.add_argument(
         "--model",
         type=str,
-        default="mlx_model",
         help="The path to the local model directory or Hugging Face repo.",
     )
     parser.add_argument(
@@ -80,9 +80,14 @@ def setup_arg_parser():
     parser.add_argument(
         "--max-kv-size",
         type=int,
-        default=1024,
         help="Set the maximum key-value cache size",
     )
+    parser.add_argument(
+        "--kv-cache-file",
+        type=str,
+        default=None,
+        help="A file containing saved KV caches to avoid recomputing them",
+    )
     return parser
 
 
@@ -113,6 +118,24 @@ def colorprint_by_t0(s, t0):
     colorprint(color, s)
 
 
+def load_kv_cache_from_file(kv_cache_file):
+    if kv_cache_file is None:
+        return None, None
+
+    kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
+    cache_per_layer = {}
+    for k, x in kv_cache.items():
+        layer, kv_type = k.split("_")
+        if layer not in cache_per_layer:
+            cache_per_layer[layer] = {}
+        cache_per_layer[layer][kv_type] = x
+
+    cache_history = [None] * len(cache_per_layer)
+    for layer, c in cache_per_layer.items():
+        cache_history[int(layer)] = (c["keys"], c["values"])
+    return cache_history, metadata
+
+
 def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
@@ -122,13 +145,25 @@ def main():
     if args.cache_limit_gb is not None:
         mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
 
+    # Load the kv cache and metadata if a kv cache file is provided
+    cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
+
     # Building tokenizer_config
-    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
+    tokenizer_config = (
+        {} if cache_history is None else json.loads(metadata["tokenizer_config"])
+    )
+    if args.trust_remote_code:
+        tokenizer_config["trust_remote_code"] = True
     if args.eos_token is not None:
         tokenizer_config["eos_token"] = args.eos_token
 
+    # If no model path is provided then use the one in the kv cache history
+    model_path = args.model
+    if cache_history is not None and model_path is None:
+        model_path = metadata["model"]
+
     model, tokenizer = load(
-        args.model,
+        model_path,
         adapter_path=args.adapter_path,
         tokenizer_config=tokenizer_config,
     )
@@ -136,6 +171,8 @@ def main():
     if args.use_default_chat_template:
         if tokenizer.chat_template is None:
             tokenizer.chat_template = tokenizer.default_chat_template
+    elif tokenizer.chat_template is None:
+        tokenizer.chat_template = metadata["chat_template"]
 
     if not args.ignore_chat_template and (
         hasattr(tokenizer, "apply_chat_template")
@@ -145,11 +182,30 @@ def main():
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
+
+        # Treat the prompt as a suffix assuming that the prefix is in the
+        # stored kv cache.
+        if cache_history is not None:
+            test_prompt = tokenizer.apply_chat_template(
+                [{"role": "user", "content": "<query>"}],
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            prompt = prompt[test_prompt.index("<query>") :]
     else:
         prompt = args.prompt
 
     formatter = colorprint_by_t0 if args.colorize else None
 
+    # Determine the max kv size from the kv cache or passed arguments
+    max_kv_size = args.max_kv_size
+    if max_kv_size is None:
+        max_kv_size = (
+            int(metadata["max_kv_size"])
+            if cache_history is not None
+            else DEFAULT_MAX_KV_SIZE
+        )
+
     generate(
         model,
         tokenizer,
@@ -159,7 +215,8 @@ def main():
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
-        max_kv_size=args.max_kv_size,
+        max_kv_size=max_kv_size,
+        cache_history=cache_history,
     )
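The "<query>" sentinel used in cache_prompt.py and generate.py above is plain string surgery on the templated prompt: cache_prompt.py keeps everything up to the end of the user content (the prefix it caches), while generate.py keeps everything from the start of the user content onward (the suffix it still has to process). A toy illustration, using a stand-in template rather than a real tokenizer.apply_chat_template:

# Toy illustration of the "<query>" prefix/suffix split (not part of the commit).
# apply_template stands in for tokenizer.apply_chat_template(..., add_generation_prompt=True).
def apply_template(content: str) -> str:
    return f"<s>[INST] {content} [/INST]"

test = apply_template("<query>")

# cache_prompt.py: drop the template text that follows the user content, so only
# the prefix is processed and cached.
cached = apply_template("You are given this report: <long document>")
n = len(test) - test.index("<query>") - len("<query>")
prefix = cached[:-n]    # "<s>[INST] You are given this report: <long document>"

# generate.py (with --kv-cache-file): keep only the part starting at the user
# content, since the templated prefix is already in the stored KV cache.
new = apply_template("Summarize the report.")
suffix = new[test.index("<query>"):]    # "Summarize the report. [/INST]"

# Concatenated, the model sees the cached prefix followed by the new suffix
# (any separator between the two must be part of one of the prompts):
print(prefix + suffix)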
llms/mlx_lm/models/base.py

@@ -46,6 +46,7 @@ class KVCache:
         self.values[..., prev : self.offset, :] = values
         return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
 
+    @property
     def state(self):
         return self.keys, self.values
 
@@ -137,6 +138,7 @@ class RotatingKVCache:
             return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
         return self.keys, self.values
 
+    @property
     def state(self):
         return self.keys, self.values
 
llms/mlx_lm/utils.py

@@ -9,7 +9,7 @@ import shutil
 import time
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -126,6 +126,26 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
     return logits
 
 
+def make_kv_caches(
+    model: nn.Module, max_kv_size: Optional[int] = None
+) -> List[Union[KVCache, RotatingKVCache]]:
+    if hasattr(model, "make_cache"):
+        return model.make_cache()
+
+    kv_heads = (
+        [model.n_kv_heads] * len(model.layers)
+        if isinstance(model.n_kv_heads, int)
+        else model.n_kv_heads
+    )
+    if max_kv_size is not None:
+        return [
+            RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
+            for n in kv_heads
+        ]
+    else:
+        return [KVCache(model.head_dim, n) for n in kv_heads]
+
+
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
@@ -138,6 +158,7 @@ def generate_step(
     logit_bias: Optional[Dict[int, float]] = None,
     prefill_step_size: int = 512,
     max_kv_size: Optional[int] = None,
+    cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -194,21 +215,19 @@ def generate_step(
     )
 
     y = prompt
-    if hasattr(model, "make_cache"):
-        cache = model.make_cache()
-    else:
-        kv_heads = (
-            [model.n_kv_heads] * len(model.layers)
-            if isinstance(model.n_kv_heads, int)
-            else model.n_kv_heads
-        )
-        if max_kv_size is not None:
-            cache = [
-                RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
-                for n in kv_heads
-            ]
-        else:
-            cache = [KVCache(model.head_dim, n) for n in kv_heads]
+
+    # Create the KV cache for generation
+    cache = make_kv_caches(model, max_kv_size)
+
+    if cache_history is not None:
+        if len(cache_history) != len(cache):
+            raise ValueError("Wrong number of layers in the cache history")
+
+        # Set the history in the cache objects and evaluate them to prepare for
+        # generation.
+        for c, h in zip(cache, cache_history):
+            c.update_and_fetch(h[0], h[1])
+        mx.eval([c.state for c in cache])
 
     repetition_context = prompt.tolist()
 
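Beyond the command-line flow, the new cache_history hook can also be driven from Python, mirroring what generate.py's main() does. A hedged sketch under stated assumptions: "prompt_cache.safetensors" is a placeholder path written by mlx_lm.cache_prompt, the cached prompt used the raw-text path (no chat template), and generate forwards max_kv_size and cache_history to generate_step as in the call shown above:

# Hedged sketch (not part of the commit): reuse a saved prompt cache from Python.
# Assumes "prompt_cache.safetensors" was written by mlx_lm.cache_prompt and that
# the cached prompt was processed without a chat template.
from mlx_lm import generate, load
from mlx_lm.generate import load_kv_cache_from_file

cache_history, metadata = load_kv_cache_from_file("prompt_cache.safetensors")

# Use the model recorded in the cache metadata so the weights match the cache.
model, tokenizer = load(metadata["model"])

response = generate(
    model,
    tokenizer,
    prompt="Summarize the cached document in one paragraph.",
    max_kv_size=int(metadata["max_kv_size"]),
    cache_history=cache_history,
)
print(response)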
llms/setup.py

@@ -31,6 +31,7 @@ setup(
     },
     entry_points={
         "console_scripts": [
+            "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
             "mlx_lm.convert = mlx_lm.convert:main",
             "mlx_lm.fuse = mlx_lm.fuse:main",
             "mlx_lm.generate = mlx_lm.generate:main",