add QuantizedKVCache

Alex Barron
2024-10-26 00:23:46 -07:00
parent 9f34fdbda4
commit 48655a7f83
4 changed files with 154 additions and 10 deletions


@@ -7,7 +7,13 @@ import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
quantized_kv: bool = False,
kv_group_size: int = 64,
kv_bits: int = 8,
) -> List[Any]:
"""
Construct the model's cache for use when generating.
@@ -24,7 +30,12 @@ def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
if quantized_kv:
return [
QuantizedKVCache(group_size=kv_group_size, bits=kv_bits)
for _ in range(num_layers)
]
elif max_kv_size is not None:
return [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
]
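With quantized_kv=True, every layer gets a QuantizedKVCache instead of a KVCache or RotatingKVCache. A minimal sketch of the new call site, assuming a model loaded via mlx_lm.load and that this file is importable as mlx_lm.models.cache (the model id is a placeholder):

from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/some-model-4bit")  # hypothetical model id

# One 8-bit QuantizedKVCache per layer, quantizing in groups of 64.
cache = make_prompt_cache(model, quantized_kv=True, kv_group_size=64, kv_bits=8)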
@@ -51,7 +62,9 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
mx.save_safetensors(file_name, cache_data, cache_metadata)
def load_prompt_cache(file_name, return_metadata=False):
def load_prompt_cache(
file_name, return_metadata=False, quantized_kv=False, kv_group_size=64, kv_bits=8
):
"""
Load a prompt cache from a file.
@@ -72,6 +85,8 @@ def load_prompt_cache(file_name, return_metadata=False):
for c, state, meta_state in zip(cache, arrays, info):
c.state = state
c.meta_state = meta_state
if quantized_kv:
cache = [c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in cache]
if return_metadata:
return cache, metadata
return cache
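This lets a cache saved at full precision be reloaded in quantized form, converting each layer through its to_quantized method after the state is restored. A short usage sketch; the file name and metadata are placeholders:

from mlx_lm.models.cache import save_prompt_cache, load_prompt_cache

save_prompt_cache("prompt_cache.safetensors", cache, {"model": "my-model"})

# Reload the same cache, converting every layer to 8-bit on the way in.
cache, metadata = load_prompt_cache(
    "prompt_cache.safetensors",
    return_metadata=True,
    quantized_kv=True,
    kv_bits=8,
)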
@@ -126,6 +141,86 @@ class _BaseCache:
return False
class QuantizedKVCache(_BaseCache):
def __init__(self, group_size: int = 64, bits: int = 4):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
self.group_size = group_size
self.bits = bits
def update_and_fetch(self, keys, values):
B, n_kv_heads, num_steps, k_head_dim = keys.shape
prev = self.offset
if self.keys is None or (prev + num_steps) > self.keys[0].shape[2]:
# Number of quantized elements packed into each uint32.
el_per_int = 8 * mx.uint32.size // self.bits
# Grow the storage in multiples of `step` tokens.
n_steps = (self.step + num_steps - 1) // self.step
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
scales_dim = k_head_dim // self.group_size
k_scale_shape = k_shape[:-1] + (scales_dim,)
# Values are assumed to share the key head dimension.
v_shape = k_shape
scale_bias_init = lambda: mx.zeros(k_scale_shape, keys.dtype)
new_k = (mx.zeros(k_shape, mx.uint32), scale_bias_init(), scale_bias_init())
new_v = (mx.zeros(v_shape, mx.uint32), scale_bias_init(), scale_bias_init())
if self.keys is not None:
if prev % self.step != 0:
self.keys = tuple(x[..., :prev, :] for x in self.keys)
self.values = tuple(x[..., :prev, :] for x in self.values)
self.keys = tuple(
mx.concatenate([self.keys[i], new_k[i]], axis=2) for i in range(3)
)
self.values = tuple(
mx.concatenate([self.values[i], new_v[i]], axis=2) for i in range(3)
)
else:
self.keys, self.values = new_k, new_v
self.offset += num_steps
if num_steps > 1:
# Prompt processing: quantize the whole chunk at once before copying it in.
keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
for i in range(len(self.keys)):
self.keys[i][..., prev : self.offset, :] = keys[i]
self.values[i][..., prev : self.offset, :] = values[i]
else:
# Single-token decode: a fused kernel quantizes and writes in one step.
outputs = mx.fast.quantized_kv_update(
keys,
values,
*self.keys,
*self.values,
prev,
group_size=self.group_size,
bits=self.bits
)
self.keys = outputs[:3]
self.values = outputs[3:]
return (
tuple(x[..., : self.offset, :] for x in self.keys),
tuple(x[..., : self.offset, :] for x in self.values),
)
@property
def state(self):
return self.keys, self.values
@classmethod
def from_cache(
cls, cache: _BaseCache, group_size: int = 64, bits: int = 4
) -> "QuantizedKVCache":
quant_cache = cls(group_size=group_size, bits=bits)
quant_cache.offset = cache.offset
# Skip quantizing when the source cache has not seen any tokens yet.
if cache.keys is not None:
    quant_cache.keys = mx.quantize(cache.keys, group_size=group_size, bits=bits)
    quant_cache.values = mx.quantize(cache.values, group_size=group_size, bits=bits)
return quant_cache
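Unlike KVCache, update_and_fetch here returns (packed uint32, scales, biases) triples rather than plain arrays. A small round-trip sketch of that layout, assuming QuantizedKVCache is importable from this module:

import mlx.core as mx
from mlx_lm.models.cache import QuantizedKVCache  # assumed module path

cache = QuantizedKVCache(group_size=64, bits=8)
keys = mx.random.normal((1, 8, 32, 64))    # (B, n_kv_heads, num_steps, head_dim)
values = mx.random.normal((1, 8, 32, 64))

qkeys, qvalues = cache.update_and_fetch(keys, values)

# Each side is a (packed uint32, scales, biases) triple; dequantize to check.
restored = mx.dequantize(*qkeys, group_size=64, bits=8)
print(restored.shape)                 # (1, 8, 32, 64)
print(mx.abs(restored - keys).max())  # small quantization error

At attention time the triples can be fed to mx.quantized_matmul directly, so the keys and values never need to be fully dequantized.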
class KVCache(_BaseCache):
def __init__(self):
self.keys = None
@@ -180,6 +275,9 @@ class KVCache(_BaseCache):
self.offset -= n
return n
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
return QuantizedKVCache.from_cache(self, group_size=group_size, bits=bits)
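to_quantized is what load_prompt_cache calls above: it lets a full-precision cache be swapped for a quantized one mid-stream. A sketch, assuming KVCache is importable from this module:

import mlx.core as mx
from mlx_lm.models.cache import KVCache  # assumed module path

fp_cache = KVCache()
k = mx.random.normal((1, 8, 16, 64))
v = mx.random.normal((1, 8, 16, 64))
fp_cache.update_and_fetch(k, v)

# Convert the accumulated entries to 4-bit; the offset carries over so
# generation can continue from the same position.
q_cache = fp_cache.to_quantized(group_size=64, bits=4)
print(q_cache.offset)  # 16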
class RotatingKVCache(_BaseCache):