mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Quantized KV Cache (#1075)
* add QuantizedKVCache
* simplify
* add tests
* single sdpa function
* fix sed
* in place
* fix tests
* support different k and v head dims
This commit is contained in:
@@ -6,7 +6,7 @@ import sys
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .models.cache import load_prompt_cache
|
||||
from .models.cache import QuantizedKVCache, load_prompt_cache
|
||||
from .utils import generate, load
|
||||
|
||||
DEFAULT_PROMPT = "hello"
|
||||
@@ -15,6 +15,7 @@ DEFAULT_TEMP = 0.0
|
||||
DEFAULT_TOP_P = 1.0
|
||||
DEFAULT_SEED = 0
|
||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
DEFAULT_QUANTIZED_KV_START = 5000
|
||||
|
||||
|
||||
def str2bool(string):
|
||||
@@ -107,6 +108,26 @@ def setup_arg_parser():
|
||||
default=None,
|
||||
help="A file containing saved KV caches to avoid recomputing them",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-bits",
|
||||
type=int,
|
||||
help="Number of bits for KV cache quantization. "
|
||||
"Defaults to no quantization.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-group-size",
|
||||
type=int,
|
||||
help="Group size for KV cache quantization.",
|
||||
default=64,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantized-kv-start",
|
||||
help="When --kv-bits is set, start quantizing the KV cache "
|
||||
"from this step onwards.",
|
||||
type=int,
|
||||
default=DEFAULT_QUANTIZED_KV_START,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@@ -150,8 +171,18 @@ def main():
|
||||
using_cache = args.prompt_cache_file is not None
|
||||
if using_cache:
|
||||
prompt_cache, metadata = load_prompt_cache(
|
||||
args.prompt_cache_file, return_metadata=True
|
||||
args.prompt_cache_file,
|
||||
return_metadata=True,
|
||||
)
|
||||
if isinstance(prompt_cache[0], QuantizedKVCache):
|
||||
if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
|
||||
raise ValueError(
|
||||
"--kv-bits does not match the kv cache loaded from --prompt-cache-file."
|
||||
)
|
||||
if args.kv_group_size != prompt_cache[0].group_size:
|
||||
raise ValueError(
|
||||
"--kv-group-size does not match the kv cache loaded from --prompt-cache-file."
|
||||
)
|
||||
|
||||
# Building tokenizer_config
|
||||
tokenizer_config = (
|
||||
@@ -227,6 +258,9 @@ def main():
|
||||
top_p=args.top_p,
|
||||
max_kv_size=args.max_kv_size,
|
||||
prompt_cache=prompt_cache if using_cache else None,
|
||||
kv_bits=args.kv_bits,
|
||||
kv_group_size=args.kv_group_size,
|
||||
quantized_kv_start=args.quantized_kv_start,
|
||||
)
|
||||
if not args.verbose:
|
||||
print(response)
|
||||
|
Reference in New Issue
Block a user