Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-19 11:28:07 +08:00)
Reorg + fixes to caching: unify prompt caching across cache types and use cases (e.g. caching during a chat)
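As a rough illustration of the chat use case named in the commit message, here is a minimal sketch of reusing one prompt cache across turns with the generation script reworked in the diff below. It leans on the load/generate API visible in the diff; the make_prompt_cache helper in mlx_lm.models.cache and the example model name are assumptions not shown in this excerpt.

# Sketch only: reuse a single prompt cache across chat turns.
# Assumes mlx_lm exposes load/generate as in the diff, a make_prompt_cache
# helper in mlx_lm.models.cache, and that the example model is available.
from mlx_lm import generate, load
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# One cache object carries the conversation's KV state between calls.
prompt_cache = make_prompt_cache(model)

for turn in ["hello", "summarize what I said so far"]:
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": turn}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # Each call resumes from the cached prefix instead of re-encoding it.
    print(generate(model, tokenizer, prompt, prompt_cache=prompt_cache))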
@@ -6,6 +6,7 @@ import sys
 
 import mlx.core as mx
 
+from .models.cache import load_prompt_cache
 from .utils import generate, load
 
 DEFAULT_PROMPT = "hello"
@@ -96,7 +97,7 @@ def setup_arg_parser():
         default=None,
     )
     parser.add_argument(
-        "--kv-cache-file",
+        "--prompt-cache-file",
         type=str,
         default=None,
         help="A file containing saved KV caches to avoid recomputing them",
@@ -131,24 +132,6 @@ def colorprint_by_t0(s, t0):
     colorprint(color, s)
 
 
-def load_kv_cache_from_file(kv_cache_file):
-    if kv_cache_file is None:
-        return None, None
-
-    kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
-    cache_per_layer = {}
-    for k, x in kv_cache.items():
-        layer, kv_type = k.split("_")
-        if layer not in cache_per_layer:
-            cache_per_layer[layer] = {}
-        cache_per_layer[layer][kv_type] = x
-
-    cache_history = [None] * len(cache_per_layer)
-    for layer, c in cache_per_layer.items():
-        cache_history[int(layer)] = (c["keys"], c["values"])
-    return cache_history, metadata
-
-
 def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
@@ -158,22 +141,32 @@ def main():
     if args.cache_limit_gb is not None:
         mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
 
-    # Load the kv cache and metadata if a kv cache file is provided
-    cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
+    # Load the prompt cache and metadata if a cache file is provided
+    using_cache = args.prompt_cache_file is not None
+    if using_cache:
+        prompt_cache, metadata = load_prompt_cache(
+            args.prompt_cache_file, return_metadata=True
+        )
 
     # Building tokenizer_config
     tokenizer_config = (
-        {} if cache_history is None else json.loads(metadata["tokenizer_config"])
+        {} if not using_cache else json.loads(metadata["tokenizer_config"])
     )
     if args.trust_remote_code:
         tokenizer_config["trust_remote_code"] = True
     if args.eos_token is not None:
         tokenizer_config["eos_token"] = args.eos_token
 
-    # If no model path is provided then use the one in the kv cache history
     model_path = args.model
-    if cache_history is not None and model_path is None:
-        model_path = metadata["model"]
+    if using_cache:
+        if model_path is None:
+            model_path = metadata["model"]
+        elif model_path != metadata["model"]:
+            raise ValueError(
+                f"Providing a different model ({model_path}) than that "
+                f"used to create the prompt cache ({metadata['model']}) "
+                "is an error."
+            )
 
     model, tokenizer = load(
         model_path,
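The metadata keys consumed in this hunk (tokenizer_config, model, chat_template) have to be written by whatever produced the cache file, typically a caching script outside this excerpt. A hedged producer-side sketch, assuming a save_prompt_cache counterpart in mlx_lm.models.cache with a (path, cache, metadata) signature:

# Sketch only: write a cache file that --prompt-cache-file can later consume.
# save_prompt_cache and its metadata argument are assumptions inferred from
# the load_prompt_cache(..., return_metadata=True) call above.
import json

from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache

repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"  # example model name
model, tokenizer = load(repo)

prompt_cache = make_prompt_cache(model)
# ... run the long shared prefix through the model to fill prompt_cache ...

metadata = {
    "model": repo,
    "chat_template": tokenizer.chat_template,
    # Stored as a JSON string because main() calls json.loads on it.
    "tokenizer_config": json.dumps({"trust_remote_code": False}),
}
save_prompt_cache("prompt_cache.safetensors", prompt_cache, metadata)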
@@ -184,7 +177,7 @@ def main():
     if args.use_default_chat_template:
         if tokenizer.chat_template is None:
             tokenizer.chat_template = tokenizer.default_chat_template
-    elif cache_history is not None:
+    elif using_cache:
         tokenizer.chat_template = metadata["chat_template"]
 
     if not args.ignore_chat_template and (
@@ -203,7 +196,7 @@ def main():
 
         # Treat the prompt as a suffix assuming that the prefix is in the
         # stored kv cache.
-        if cache_history is not None:
+        if using_cache:
             test_prompt = tokenizer.apply_chat_template(
                 [{"role": "user", "content": "<query>"}],
                 tokenize=False,
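The "<query>" placeholder in this hunk exists because apply_chat_template always renders the template preamble, yet those preamble tokens are already inside the stored cache; the code keeps only the part of the rendered prompt that starts where the placeholder would. A small sketch of that trick in isolation (the helper name is illustrative, not from the source):

# Sketch of the "<query>" suffix trick used above: render the template once
# with a placeholder to find where user content begins, then keep only the
# matching tail of the real prompt so the cached prefix is not re-sent.
def prompt_suffix(tokenizer, rendered_prompt: str) -> str:
    test_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "<query>"}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # Everything before the placeholder is preamble the cache already covers.
    return rendered_prompt[test_prompt.index("<query>") :]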
@@ -217,12 +210,6 @@ def main():
         raise ValueError("Cannot use --colorize with --verbose=False")
     formatter = colorprint_by_t0 if args.colorize else None
 
-    # Determine the max kv size from the kv cache or passed arguments
-    max_kv_size = args.max_kv_size
-    if cache_history is not None:
-        max_kv_size = metadata["max_kv_size"]
-        max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
-
     response = generate(
         model,
         tokenizer,
@@ -232,8 +219,8 @@ def main():
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
-        max_kv_size=max_kv_size,
-        cache_history=cache_history,
+        max_kv_size=args.max_kv_size,
+        prompt_cache=prompt_cache if using_cache else None,
     )
     if not args.verbose:
         print(response)