From 1003a8b2dd59a22255a6a6c9a20f9d41f9812fb5 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Wed, 28 Aug 2024 22:11:45 -0700
Subject: [PATCH] Add the ability to load the KV cache from a file (#956)

---
 llms/mlx_lm/cache_prompt.py | 149 ++++++++++++++++++++++++++++++++++++
 llms/mlx_lm/generate.py     |  69 +++++++++++++++--
 llms/mlx_lm/models/base.py  |   2 +
 llms/mlx_lm/utils.py        |  51 ++++++++----
 llms/setup.py               |   1 +
 5 files changed, 250 insertions(+), 22 deletions(-)
 create mode 100644 llms/mlx_lm/cache_prompt.py

diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
new file mode 100644
index 00000000..ad045f1a
--- /dev/null
+++ b/llms/mlx_lm/cache_prompt.py
@@ -0,0 +1,149 @@
+# Copyright © 2024 Apple Inc.
+
+import argparse
+import json
+import sys
+import time
+
+import mlx.core as mx
+
+from .utils import load, make_kv_caches
+
+
+def setup_arg_parser():
+    """Set up and return the argument parser."""
+    parser = argparse.ArgumentParser(
+        description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="mlx_model",
+        help="The path to the local model directory or Hugging Face repo.",
+    )
+    parser.add_argument(
+        "--adapter-path",
+        type=str,
+        help="Optional path for the trained adapter weights and config.",
+    )
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Enable trusting remote code for tokenizer",
+    )
+    parser.add_argument(
+        "--eos-token",
+        type=str,
+        default=None,
+        help="End of sequence token for tokenizer",
+    )
+    parser.add_argument(
+        "--ignore-chat-template",
+        action="store_true",
+        help="Use the raw prompt without the tokenizer's chat template.",
+    )
+    parser.add_argument(
+        "--use-default-chat-template",
+        action="store_true",
+        help="Use the default chat template",
+    )
+    parser.add_argument(
+        "--cache-limit-gb",
+        type=int,
+        default=None,
+        help="Set the MLX cache limit in GB",
+    )
+    parser.add_argument(
+        "--max-kv-size",
+        type=int,
+        default=1024,
+        help="Set the maximum key-value cache size",
+    )
+    parser.add_argument(
+        "--kv-cache-file", help="The file to save the KV caches in", required=True
+    )
+    parser.add_argument(
+        "--prompt",
+        required=True,
+        help="Message to be processed by the model ('-' reads from stdin)",
+    )
+    return parser
+
+
+def main():
+    parser = setup_arg_parser()
+    args = parser.parse_args()
+
+    if args.cache_limit_gb is not None:
+        mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
+
+    # Building tokenizer_config
+    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
+    if args.eos_token is not None:
+        tokenizer_config["eos_token"] = args.eos_token
+
+    model, tokenizer = load(
+        args.model,
+        adapter_path=args.adapter_path,
+        tokenizer_config=tokenizer_config,
+    )
+
+    args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
+
+    if args.use_default_chat_template:
+        if tokenizer.chat_template is None:
+            tokenizer.chat_template = tokenizer.default_chat_template
+
+    if not args.ignore_chat_template and (
+        hasattr(tokenizer, "apply_chat_template")
+        and tokenizer.chat_template is not None
+    ):
+        messages = [{"role": "user", "content": args.prompt}]
+        prompt = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+
+        # Treat the prompt as a prefix assuming that the suffix will be
+        # provided at generation time.
+        test_prompt = tokenizer.apply_chat_template(
+            [{"role": "user", "content": "<query>"}],
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+        n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
+        prompt = prompt[:-n]
+    else:
+        prompt = args.prompt
+
+    cache = make_kv_caches(model, args.max_kv_size)
+    y = mx.array(tokenizer.encode(prompt))
+
+    # Process the prompt
+    processed = 0
+    step_size = 512
+    start = time.time()
+    max_msg_len = 0
+    while y.size > 0:
+        model(y[:step_size][None], cache=cache)
+        mx.eval([c.state for c in cache])
+        processed += min(y.size, step_size)
+        y = y[step_size:]
+        current = time.time()
+        speed = processed / (current - start)
+        msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
+        max_msg_len = max(max_msg_len, len(msg))
+        print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
+    print()
+    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
+
+    print("Saving...")
+    cache_dict = {}
+    for i, c in enumerate(cache):
+        cache_dict[f"{i}_keys"] = c.state[0]
+        cache_dict[f"{i}_values"] = c.state[1]
+    metadata = {}
+    metadata["model"] = args.model
+    metadata["chat_template"] = tokenizer.chat_template
+    metadata["tokenizer_config"] = json.dumps(tokenizer_config)
+    metadata["max_kv_size"] = str(args.max_kv_size)
+    mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 6707d25c..4aa4001a 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -1,17 +1,18 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import argparse
+import json
 
 import mlx.core as mx
 
 from .utils import generate, load
 
-DEFAULT_MODEL_PATH = "mlx_model"
 DEFAULT_PROMPT = "hello"
 DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.6
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
+DEFAULT_MAX_KV_SIZE = 1024
 
 
 def setup_arg_parser():
@@ -20,7 +21,6 @@ def setup_arg_parser():
     parser.add_argument(
         "--model",
         type=str,
-        default="mlx_model",
         help="The path to the local model directory or Hugging Face repo.",
     )
     parser.add_argument(
@@ -80,9 +80,14 @@ def setup_arg_parser():
     parser.add_argument(
         "--max-kv-size",
         type=int,
-        default=1024,
         help="Set the maximum key-value cache size",
     )
+    parser.add_argument(
+        "--kv-cache-file",
+        type=str,
+        default=None,
+        help="A file containing saved KV caches to avoid recomputing them",
+    )
     return parser
 
 
@@ -113,6 +118,24 @@ def colorprint_by_t0(s, t0):
     colorprint(color, s)
 
 
+def load_kv_cache_from_file(kv_cache_file):
+    if kv_cache_file is None:
+        return None, None
+
+    kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
+    cache_per_layer = {}
+    for k, x in kv_cache.items():
+        layer, kv_type = k.split("_")
+        if layer not in cache_per_layer:
+            cache_per_layer[layer] = {}
+        cache_per_layer[layer][kv_type] = x
+
+    cache_history = [None] * len(cache_per_layer)
+    for layer, c in cache_per_layer.items():
+        cache_history[int(layer)] = (c["keys"], c["values"])
+    return cache_history, metadata
+
+
 def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
@@ -122,13 +145,25 @@ def main():
     if args.cache_limit_gb is not None:
         mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
 
+    # Load the kv cache and metadata if a kv cache file is provided
+    cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
+
     # Building tokenizer_config
-    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
+    tokenizer_config = (
+        {} if cache_history is None else json.loads(metadata["tokenizer_config"])
+    )
+    if args.trust_remote_code:
+        tokenizer_config["trust_remote_code"] = True
     if args.eos_token is not None:
         tokenizer_config["eos_token"] = args.eos_token
 
+    # If no model path is provided then use the one in the kv cache history
+    model_path = args.model
+    if cache_history is not None and model_path is None:
+        model_path = metadata["model"]
+
     model, tokenizer = load(
-        args.model,
+        model_path,
         adapter_path=args.adapter_path,
         tokenizer_config=tokenizer_config,
     )
@@ -136,6 +171,8 @@ def main():
     if args.use_default_chat_template:
         if tokenizer.chat_template is None:
             tokenizer.chat_template = tokenizer.default_chat_template
+    elif tokenizer.chat_template is None:
+        tokenizer.chat_template = metadata["chat_template"]
 
     if not args.ignore_chat_template and (
         hasattr(tokenizer, "apply_chat_template")
@@ -145,11 +182,30 @@ def main():
         messages = [{"role": "user", "content": args.prompt}]
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
+
+        # Treat the prompt as a suffix assuming that the prefix is in the
+        # stored kv cache.
+        if cache_history is not None:
+            test_prompt = tokenizer.apply_chat_template(
+                [{"role": "user", "content": "<query>"}],
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            prompt = prompt[test_prompt.index("<query>") :]
     else:
         prompt = args.prompt
 
     formatter = colorprint_by_t0 if args.colorize else None
 
+    # Determine the max kv size from the kv cache or passed arguments
+    max_kv_size = args.max_kv_size
+    if max_kv_size is None:
+        max_kv_size = (
+            int(metadata["max_kv_size"])
+            if cache_history is not None
+            else DEFAULT_MAX_KV_SIZE
+        )
+
     generate(
         model,
         tokenizer,
@@ -159,7 +215,8 @@ def main():
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
-        max_kv_size=args.max_kv_size,
+        max_kv_size=max_kv_size,
+        cache_history=cache_history,
     )
 
 
diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index 3e84554c..dc19dd05 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -46,6 +46,7 @@ class KVCache:
         self.values[..., prev : self.offset, :] = values
         return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
 
+    @property
     def state(self):
         return self.keys, self.values
 
@@ -137,6 +138,7 @@ class RotatingKVCache:
             return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
         return self.keys, self.values
 
+    @property
     def state(self):
         return self.keys, self.values
 
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 44196766..71476df3 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -9,7 +9,7 @@ import shutil
 import time
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -126,6 +126,26 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
     return logits
 
 
+def make_kv_caches(
+    model: nn.Module, max_kv_size: Optional[int] = None
+) -> List[Union[KVCache, RotatingKVCache]]:
+    if hasattr(model, "make_cache"):
+        return model.make_cache()
+
+    kv_heads = (
+        [model.n_kv_heads] * len(model.layers)
+        if isinstance(model.n_kv_heads, int)
+        else model.n_kv_heads
+    )
+    if max_kv_size is not None:
+        return [
+            RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
+            for n in kv_heads
+        ]
+    else:
+        return [KVCache(model.head_dim, n) for n in kv_heads]
+
+
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
@@ -138,6 +158,7 @@ def generate_step(
     logit_bias: Optional[Dict[int, float]] = None,
     prefill_step_size: int = 512,
     max_kv_size: Optional[int] = None,
+    cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -194,21 +215,19 @@ def generate_step(
     )
 
     y = prompt
-    if hasattr(model, "make_cache"):
-        cache = model.make_cache()
-    else:
-        kv_heads = (
-            [model.n_kv_heads] * len(model.layers)
-            if isinstance(model.n_kv_heads, int)
-            else model.n_kv_heads
-        )
-        if max_kv_size is not None:
-            cache = [
-                RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
-                for n in kv_heads
-            ]
-        else:
-            cache = [KVCache(model.head_dim, n) for n in kv_heads]
+
+    # Create the KV cache for generation
+    cache = make_kv_caches(model, max_kv_size)
+
+    if cache_history is not None:
+        if len(cache_history) != len(cache):
+            raise ValueError("Wrong number of layers in the cache history")
+
+        # Set the history in the cache objects and evaluate them to prepare for
+        # generation.
+        for c, h in zip(cache, cache_history):
+            c.update_and_fetch(h[0], h[1])
+        mx.eval([c.state for c in cache])
 
     repetition_context = prompt.tolist()
 
diff --git a/llms/setup.py b/llms/setup.py
index 88deed17..ac294ae1 100644
--- a/llms/setup.py
+++ b/llms/setup.py
@@ -31,6 +31,7 @@ setup(
     },
     entry_points={
         "console_scripts": [
+            "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
             "mlx_lm.convert = mlx_lm.convert:main",
             "mlx_lm.fuse = mlx_lm.fuse:main",
             "mlx_lm.generate = mlx_lm.generate:main",
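
Note (not part of the diff): below is a minimal, hedged sketch of the round trip this change enables, using only APIs the diff introduces or already relies on (make_kv_caches, mx.save_safetensors / mx.load, and generate() forwarding extra keyword arguments such as cache_history to generate_step). The model repo, file name, and prompt strings are placeholders; the CLI equivalent is mlx_lm.cache_prompt followed by mlx_lm.generate --kv-cache-file.

# Sketch only: illustrates the save/reload flow added by this patch.
import mlx.core as mx
from mlx_lm import generate, load
from mlx_lm.utils import make_kv_caches

# Placeholder model repo.
model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")

# 1) Pre-fill a KV cache with a long shared prefix (what mlx_lm.cache_prompt
#    does) and save it with per-layer "{i}_keys" / "{i}_values" entries.
prefix = mx.array(tokenizer.encode("A long system prompt or document goes here."))
cache = make_kv_caches(model, max_kv_size=1024)
model(prefix[None], cache=cache)
mx.eval([c.state for c in cache])

cache_dict = {}
for i, c in enumerate(cache):
    cache_dict[f"{i}_keys"] = c.state[0]
    cache_dict[f"{i}_values"] = c.state[1]
mx.save_safetensors("prefix_cache.safetensors", cache_dict)

# 2) Later: reload the file, rebuild the per-layer (keys, values) history, and
#    hand it to generate(), which forwards cache_history to generate_step
#    (what mlx_lm.generate --kv-cache-file does).
kv_cache = mx.load("prefix_cache.safetensors")
num_layers = len(kv_cache) // 2
cache_history = [
    (kv_cache[f"{i}_keys"], kv_cache[f"{i}_values"]) for i in range(num_layers)
]
response = generate(
    model,
    tokenizer,
    prompt="Now the question about that prefix.",
    max_tokens=100,
    verbose=True,
    max_kv_size=1024,
    cache_history=cache_history,
)

The chat-template prefix/suffix trimming that the two CLI scripts perform around the stored prefix is omitted here for brevity.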