add QuantizedKVCache

Alex Barron 2024-10-26 00:23:46 -07:00
parent 9f34fdbda4
commit 48655a7f83
4 changed files with 154 additions and 10 deletions


@@ -107,6 +107,23 @@ def setup_arg_parser():
default=None,
help="A file containing saved KV caches to avoid recomputing them",
)
parser.add_argument(
"--quantized-kv",
help="Whether to quantize the KV cache.",
action="store_true",
)
parser.add_argument(
"--kv-group-size",
type=int,
help="Group size for kv cache quantization.",
default=64,
)
parser.add_argument(
"--kv-bits",
type=int,
help="Number of bits for kv cache quantization.",
default=8,
)
return parser
@@ -150,7 +167,11 @@ def main():
using_cache = args.prompt_cache_file is not None
if using_cache:
prompt_cache, metadata = load_prompt_cache(
args.prompt_cache_file, return_metadata=True
args.prompt_cache_file,
return_metadata=True,
quantized_kv=args.quantized_kv,
kv_group_size=args.kv_group_size,
kv_bits=args.kv_bits,
)
# Building tokenizer_config
@@ -227,6 +248,9 @@ def main():
top_p=args.top_p,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
quantized_kv=args.quantized_kv,
kv_group_size=args.kv_group_size,
kv_bits=args.kv_bits,
)
if not args.verbose:
print(response)
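
For reference, a rough Python equivalent of running the CLI with the new flags. This is a minimal sketch that assumes the top-level mlx_lm load and generate helpers and that generate forwards extra keyword arguments to generate_step; the model path is a placeholder.

# Sketch only; mirrors `--quantized-kv --kv-group-size 64 --kv-bits 8`.
from mlx_lm import load, generate

model, tokenizer = load("path/to/model")  # placeholder path
text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about KV caches.",
    quantized_kv=True,
    kv_group_size=64,
    kv_bits=8,
)
print(text)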


@@ -7,7 +7,13 @@ import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
quantized_kv: bool = False,
kv_group_size: int = 64,
kv_bits: int = 8,
) -> List[Any]:
"""
Construct the model's cache for use when generating.
@@ -24,7 +30,12 @@ def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
if quantized_kv:
return [
QuantizedKVCache(group_size=kv_group_size, bits=kv_bits)
for _ in range(num_layers)
]
elif max_kv_size is not None:
return [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
]
@@ -51,7 +62,9 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str]
mx.save_safetensors(file_name, cache_data, cache_metadata)
def load_prompt_cache(file_name, return_metadata=False):
def load_prompt_cache(
file_name, return_metadata=False, quantized_kv=False, kv_group_size=64, kv_bits=8
):
"""
Load a prompt cache from a file.
@@ -72,6 +85,8 @@ def load_prompt_cache(file_name, return_metadata=False):
for c, state, meta_state in zip(cache, arrays, info):
c.state = state
c.meta_state = meta_state
if quantized_kv:
cache = [c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in cache]
if return_metadata:
return cache, metadata
return cache
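
One intended workflow for the loader change above, sketched under the assumption that the cache helpers are importable as mlx_lm.models.cache (consistent with the `from .models import cache` import in utils.py below) and that save_prompt_cache's metadata argument is optional: a cache saved in the default KVCache format can be reloaded in quantized form, since load_prompt_cache now calls to_quantized on each layer's cache when quantized_kv=True.

# Illustrative sketch; the model path is a placeholder.
from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache, load_prompt_cache

model, tokenizer = load("path/to/model")
cache = make_prompt_cache(model)  # regular KVCache per layer
# ... run the prompt through the model with this cache to fill it ...
save_prompt_cache("prompt.safetensors", cache)
qcache = load_prompt_cache(
    "prompt.safetensors", quantized_kv=True, kv_group_size=64, kv_bits=8
)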
@@ -126,6 +141,86 @@ class _BaseCache:
return False
class QuantizedKVCache(_BaseCache):
def __init__(self, group_size: int = 64, bits: int = 4):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
self.group_size = group_size
self.bits = bits
def update_and_fetch(self, keys, values):
B, n_kv_heads, num_steps, k_head_dim = keys.shape
prev = self.offset
if self.keys is None or (prev + num_steps) > self.keys[0].shape[2]:
el_per_int = 8 * mx.uint32.size // self.bits
n_steps = (self.step + num_steps - 1) // self.step
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
scales_dim = k_head_dim // self.group_size
k_scale_shape = k_shape[:-1] + (scales_dim,)
v_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
scale_bias_init = lambda: mx.zeros(k_scale_shape, keys.dtype)
new_k = (mx.zeros(k_shape, mx.uint32), scale_bias_init(), scale_bias_init())
new_v = (mx.zeros(v_shape, mx.uint32), scale_bias_init(), scale_bias_init())
if self.keys is not None:
if prev % self.step != 0:
self.keys = tuple(x[..., :prev, :] for x in self.keys)
self.values = tuple(x[..., :prev, :] for x in self.values)
self.keys = tuple(
mx.concatenate([self.keys[i], new_k[i]], axis=2) for i in range(3)
)
self.values = tuple(
mx.concatenate([self.values[i], new_v[i]], axis=2) for i in range(3)
)
else:
self.keys, self.values = new_k, new_v
self.offset += num_steps
if num_steps > 1:
keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
for i in range(len(self.keys)):
self.keys[i][..., prev : self.offset, :] = keys[i]
self.values[i][..., prev : self.offset, :] = values[i]
else:
outputs = mx.fast.quantized_kv_update(
keys,
values,
*self.keys,
*self.values,
prev,
group_size=self.group_size,
bits=self.bits
)
self.keys = outputs[:3]
self.values = outputs[3:]
return (
tuple(x[..., : self.offset, :] for x in self.keys),
tuple(x[..., : self.offset, :] for x in self.values),
)
def state(self):
return self.keys, self.values
@classmethod
def from_cache(
cls, cache: _BaseCache, group_size: int = 64, bits: int = 4
) -> "QuantizedKVCache":
quant_cache = cls(group_size=group_size, bits=bits)
quant_cache.offset = cache.offset
quant_cache.keys = mx.quantize(cache.keys, group_size=group_size, bits=bits)
quant_cache.values = mx.quantize(cache.values, group_size=group_size, bits=bits)
return quant_cache
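
For intuition on the shapes in update_and_fetch: each of self.keys and self.values is a (packed data, scales, biases) triple. Since mx.uint32.size is 4 bytes, bits=8 gives el_per_int = 8 * 4 // 8 = 4 values per uint32, so a head_dim of 128 packs into 128 // 4 = 32 uint32 columns, with 128 // group_size = 2 scale and bias columns per position. A quick shape check (illustrative; relies on mx.quantize accepting batched input along the last axis, as the cache code above does):

import mlx.core as mx

k = mx.random.normal((1, 8, 1, 128))  # (B, n_kv_heads, steps, head_dim)
w_q, scales, biases = mx.quantize(k, group_size=64, bits=8)
print(w_q.shape, scales.shape, biases.shape)
# expected: (1, 8, 1, 32) (1, 8, 1, 2) (1, 8, 1, 2)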
class KVCache(_BaseCache):
def __init__(self):
self.keys = None
@@ -180,6 +275,9 @@ class KVCache(_BaseCache):
self.offset -= n
return n
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
return QuantizedKVCache.from_cache(self, group_size=group_size, bits=bits)
class RotatingKVCache(_BaseCache):


@@ -1,12 +1,13 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask
from .cache import QuantizedKVCache
@dataclass
@@ -190,9 +191,21 @@ class Attention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
)
if isinstance(cache, QuantizedKVCache):
output = mx.fast.quantized_scaled_dot_product_attention(
queries,
*keys,
*values,
scale=self.scale,
mask=mask,
group_size=cache.group_size,
bits=cache.bits,
)
else:
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
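
The quantized branch consumes the (data, scales, biases) triples returned by the cache directly. Conceptually it should approximate dequantizing the cache and falling back to the regular kernel; inside that branch one could sanity-check it with something like the following (illustrative only, using the queries/keys/values/cache/mask names already in scope above):

k_ref = mx.dequantize(*keys, group_size=cache.group_size, bits=cache.bits)
v_ref = mx.dequantize(*values, group_size=cache.group_size, bits=cache.bits)
ref = mx.fast.scaled_dot_product_attention(
    queries, k_ref, v_ref, scale=self.scale, mask=mask
)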


@@ -19,7 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer
# Local imports
from .models import base, cache
from .models import cache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
@@ -173,6 +173,9 @@ def generate_step(
prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
quantized_kv: bool = False,
kv_group_size: int = 64,
kv_bits: int = 8,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -255,7 +258,13 @@ def generate_step(
# Create the KV cache for generation
if prompt_cache is None:
prompt_cache = cache.make_prompt_cache(model, max_kv_size)
prompt_cache = cache.make_prompt_cache(
model,
max_kv_size=max_kv_size,
quantized_kv=quantized_kv,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
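
Putting the utils.py changes together: generate_step can now either allocate QuantizedKVCache layers itself via the new keyword arguments, or accept a prebuilt cache through prompt_cache. A minimal sketch (assumes mlx_lm.load, that generate_step is imported from mlx_lm.utils, and standard tokenizer usage; the model path is a placeholder):

import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.utils import generate_step

model, tokenizer = load("path/to/model")
prompt = mx.array(tokenizer.encode("Hello"))

# Option 1: let generate_step build the quantized cache itself.
gen = generate_step(prompt, model, quantized_kv=True, kv_group_size=64, kv_bits=8)

# Option 2: build the cache explicitly and pass it in.
qcache = make_prompt_cache(model, quantized_kv=True, kv_group_size=64, kv_bits=8)
gen = generate_step(prompt, model, prompt_cache=qcache)

# generate_step yields (token, logprobs) pairs; take a few tokens.
for (token, _logprobs), _ in zip(gen, range(16)):
    print(tokenizer.decode([token.item()]), end="")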