Quantized KV Cache (#1075)

* add QuantizedKVCache

* simplify

* add tests

* single sdpa function

* fix sed

* in place

* fix tests

* support different k and v head dims
This commit is contained in:
Alex Barron
2024-10-31 16:59:52 -07:00
committed by GitHub
parent 9f34fdbda4
commit 85ffd2c96a
32 changed files with 411 additions and 85 deletions

View File

@@ -8,7 +8,9 @@ import time
import mlx.core as mx
from .models.cache import make_prompt_cache, save_prompt_cache
from .utils import load
from .utils import load, maybe_quantize_kv_cache
DEFAULT_QUANTIZED_KV_START = 5000
def setup_arg_parser():
@@ -70,6 +72,26 @@ def setup_arg_parser():
required=True,
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--kv-bits",
type=int,
help="Number of bits for KV cache quantization. "
"Defaults to no quantization.",
default=None,
)
parser.add_argument(
"--kv-group-size",
type=int,
help="Group size for KV cache quantization.",
default=64,
)
parser.add_argument(
"--quantized-kv-start",
help="When --kv-bits is set, start quantizing the KV cache "
"from this step onwards.",
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
return parser
@@ -127,6 +149,7 @@ def main():
start = time.time()
max_msg_len = 0
while y.size > 0:
model(y[:step_size][None], cache=cache)
mx.eval([c.state for c in cache])
processed += min(y.size, step_size)
@@ -136,6 +159,11 @@ def main():
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
max_msg_len = max(max_msg_len, len(msg))
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
maybe_quantize_kv_cache(
cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
)
print()
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")