add QuantizedKVCache

This commit is contained in:
Alex Barron 2024-10-26 00:23:46 -07:00
parent 9f34fdbda4
commit 48655a7f83
4 changed files with 154 additions and 10 deletions

View File

@ -107,6 +107,23 @@ def setup_arg_parser():
default=None, default=None,
help="A file containing saved KV caches to avoid recomputing them", help="A file containing saved KV caches to avoid recomputing them",
) )
parser.add_argument(
"--quantized-kv",
help="Whether to quantize the KV cache.",
action="store_true",
)
parser.add_argument(
"--kv-group-size",
type=int,
help="Group size for kv cache quantization.",
default=64,
)
parser.add_argument(
"--kv-bits",
type=int,
help="Number of bits for kv cache quantization.",
default=8,
)
return parser return parser
@ -150,7 +167,11 @@ def main():
using_cache = args.prompt_cache_file is not None using_cache = args.prompt_cache_file is not None
if using_cache: if using_cache:
prompt_cache, metadata = load_prompt_cache( prompt_cache, metadata = load_prompt_cache(
args.prompt_cache_file, return_metadata=True args.prompt_cache_file,
return_metadata=True,
quantized_kv=args.quantized_kv,
kv_group_size=args.kv_group_size,
kv_bits=args.kv_bits,
) )
# Building tokenizer_config # Building tokenizer_config
@ -227,6 +248,9 @@ def main():
top_p=args.top_p, top_p=args.top_p,
max_kv_size=args.max_kv_size, max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None, prompt_cache=prompt_cache if using_cache else None,
quantized_kv=args.quantized_kv,
kv_group_size=args.kv_group_size,
kv_bits=args.kv_bits,
) )
if not args.verbose: if not args.verbose:
print(response) print(response)

View File

@ -7,7 +7,13 @@ import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten from mlx.utils import tree_flatten, tree_unflatten
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]: def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
quantized_kv: bool = False,
kv_group_size: int = 64,
kv_bits: int = 8,
) -> List[Any]:
""" """
Construct the model's cache for use when cgeneration. Construct the model's cache for use when cgeneration.
@ -24,7 +30,12 @@ def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> Li
return model.make_cache() return model.make_cache()
num_layers = len(model.layers) num_layers = len(model.layers)
if max_kv_size is not None: if quantized_kv:
return [
QuantizedKVCache(group_size=kv_group_size, bits=kv_bits)
for _ in range(num_layers)
]
elif max_kv_size is not None:
return [ return [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers) RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
] ]
@ -51,7 +62,9 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str]
mx.save_safetensors(file_name, cache_data, cache_metadata) mx.save_safetensors(file_name, cache_data, cache_metadata)
def load_prompt_cache(file_name, return_metadata=False): def load_prompt_cache(
file_name, return_metadata=False, quantized_kv=False, kv_group_size=64, kv_bits=8
):
""" """
Load a prompt cache from a file. Load a prompt cache from a file.
@ -72,6 +85,8 @@ def load_prompt_cache(file_name, return_metadata=False):
for c, state, meta_state in zip(cache, arrays, info): for c, state, meta_state in zip(cache, arrays, info):
c.state = state c.state = state
c.meta_state = meta_state c.meta_state = meta_state
if quantized_kv:
cache = [c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in cache]
if return_metadata: if return_metadata:
return cache, metadata return cache, metadata
return cache return cache
@ -126,6 +141,86 @@ class _BaseCache:
return False return False
class QuantizedKVCache(_BaseCache):
def __init__(self, group_size: int = 64, bits: int = 4):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
self.group_size = group_size
self.bits = bits
def update_and_fetch(self, keys, values):
B, n_kv_heads, num_steps, k_head_dim = keys.shape
prev = self.offset
if self.keys is None or (prev + num_steps) > self.keys[0].shape[2]:
el_per_int = 8 * mx.uint32.size // self.bits
n_steps = (self.step + keys[0].shape[2] - 1) // self.step
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
scales_dim = k_head_dim // self.group_size
k_scale_shape = k_shape[:-1] + (scales_dim,)
v_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
scale_bias_init = lambda: mx.zeros(k_scale_shape, keys.dtype)
new_k = (mx.zeros(k_shape, mx.uint32), scale_bias_init(), scale_bias_init())
new_v = (mx.zeros(v_shape, mx.uint32), scale_bias_init(), scale_bias_init())
if self.keys is not None:
if prev % self.step != 0:
self.keys = tuple(x[..., :prev, :] for x in self.keys)
self.values = tuple(x[..., :prev, :] for x in self.values)
self.keys = tuple(
mx.concatenate([self.keys[i], new_k[i]], axis=2) for i in range(3)
)
self.values = tuple(
mx.concatenate([self.values[i], new_v[i]], axis=2) for i in range(3)
)
else:
self.keys, self.values = new_k, new_v
self.offset += num_steps
if num_steps > 1:
keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
for i in range(len(self.keys)):
self.keys[i][..., prev : self.offset, :] = keys[i]
self.values[i][..., prev : self.offset, :] = values[i]
else:
outputs = mx.fast.quantized_kv_update(
keys,
values,
*self.keys,
*self.values,
prev,
group_size=self.group_size,
bits=self.bits
)
self.keys = outputs[:3]
self.values = outputs[3:]
return (
tuple(x[..., : self.offset, :] for x in self.keys),
tuple(x[..., : self.offset, :] for x in self.values),
)
def state(self):
return self.keys, self.values
@classmethod
def from_cache(
cls, cache: _BaseCache, group_size: int = 64, bits: int = 4
) -> "QuantizedKVCache":
quant_cache = cls(group_size=group_size, bits=bits)
quant_cache.offset = cache.offset
quant_cache.keys = mx.quantize(cache.keys, group_size=group_size, bits=bits)
quant_cache.values = mx.quantize(cache.values, group_size=group_size, bits=bits)
return quant_cache
class KVCache(_BaseCache): class KVCache(_BaseCache):
def __init__(self): def __init__(self):
self.keys = None self.keys = None
@ -180,6 +275,9 @@ class KVCache(_BaseCache):
self.offset -= n self.offset -= n
return n return n
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
return QuantizedKVCache.from_cache(self, group_size=group_size, bits=bits)
class RotatingKVCache(_BaseCache): class RotatingKVCache(_BaseCache):

View File

@ -1,12 +1,13 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask
from .cache import QuantizedKVCache
@dataclass @dataclass
@ -190,9 +191,21 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
if isinstance(cache, QuantizedKVCache):
output = mx.fast.quantized_scaled_dot_product_attention(
queries,
*keys,
*values,
scale=self.scale,
mask=mask,
group_size=cache.group_size,
bits=cache.bits,
)
else:
output = mx.fast.scaled_dot_product_attention( output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -19,7 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
# Local imports # Local imports
from .models import base, cache from .models import cache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model from .tuner.utils import dequantize as dequantize_model
@ -173,6 +173,9 @@ def generate_step(
prompt_cache: Optional[Any] = None, prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None, logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
quantized_kv: bool = False,
kv_group_size: int = 64,
kv_bits: int = 8,
) -> Generator[Tuple[mx.array, mx.array], None, None]: ) -> Generator[Tuple[mx.array, mx.array], None, None]:
""" """
A generator producing token ids based on the given prompt from the model. A generator producing token ids based on the given prompt from the model.
@ -255,7 +258,13 @@ def generate_step(
# Create the KV cache for generation # Create the KV cache for generation
if prompt_cache is None: if prompt_cache is None:
prompt_cache = cache.make_prompt_cache(model, max_kv_size) prompt_cache = cache.make_prompt_cache(
model,
max_kv_size=max_kv_size,
quantized_kv=quantized_kv,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
elif len(prompt_cache) != len(model.layers): elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.") raise ValueError("Wrong number of layers in the prompt cache.")