add QuantizedKVCache
@@ -7,7 +7,13 @@ import mlx.nn as nn
 from mlx.utils import tree_flatten, tree_unflatten


-def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
+def make_prompt_cache(
+    model: nn.Module,
+    max_kv_size: Optional[int] = None,
+    quantized_kv: bool = False,
+    kv_group_size: int = 64,
+    kv_bits: int = 8,
+) -> List[Any]:
     """
     Construct the model's cache for use when generating.

@@ -24,7 +30,12 @@ def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
         return model.make_cache()

     num_layers = len(model.layers)
-    if max_kv_size is not None:
+    if quantized_kv:
+        return [
+            QuantizedKVCache(group_size=kv_group_size, bits=kv_bits)
+            for _ in range(num_layers)
+        ]
+    elif max_kv_size is not None:
         return [
             RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
         ]
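Usage sketch (not part of the commit): with the new keyword arguments, `make_prompt_cache` can hand back one `QuantizedKVCache` per layer up front. The `mlx_lm.load` import and the model path are assumptions for illustration; the keyword arguments come from the diff above.

```python
from mlx_lm import load  # assumed package entry point
from mlx_lm.models.cache import make_prompt_cache

# Hypothetical checkpoint; any mlx-lm compatible model works here.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# One QuantizedKVCache per layer: keys/values held as packed uint32 words
# with per-group scales and biases (64-element groups, 8 bits).
cache = make_prompt_cache(model, quantized_kv=True, kv_group_size=64, kv_bits=8)
```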
@@ -51,7 +62,9 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str]
     mx.save_safetensors(file_name, cache_data, cache_metadata)


-def load_prompt_cache(file_name, return_metadata=False):
+def load_prompt_cache(
+    file_name, return_metadata=False, quantized_kv=False, kv_group_size=64, kv_bits=8
+):
     """
     Load a prompt cache from a file.

@@ -72,6 +85,8 @@ def load_prompt_cache(file_name, return_metadata=False):
     for c, state, meta_state in zip(cache, arrays, info):
         c.state = state
         c.meta_state = meta_state
+    if quantized_kv:
+        cache = [c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in cache]
     if return_metadata:
         return cache, metadata
     return cache
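Round-trip sketch (not part of the commit; the file name is arbitrary): a cache written by `save_prompt_cache` is restored layer by layer, and with `quantized_kv=True` each layer is then converted through `to_quantized`, as the hunk above shows.

```python
from mlx_lm.models.cache import save_prompt_cache, load_prompt_cache

# Persist the cache built during prefill...
save_prompt_cache("prompt.safetensors", cache)

# ...and reload it in quantized form for cheaper decoding.
cache = load_prompt_cache(
    "prompt.safetensors", quantized_kv=True, kv_group_size=64, kv_bits=8
)
```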
@@ -126,6 +141,86 @@ class _BaseCache:
         return False


+class QuantizedKVCache(_BaseCache):
+    def __init__(self, group_size: int = 64, bits: int = 4):
+        self.keys = None
+        self.values = None
+        self.offset = 0
+        self.step = 256
+        self.group_size = group_size
+        self.bits = bits
+
+    def update_and_fetch(self, keys, values):
+        B, n_kv_heads, num_steps, k_head_dim = keys.shape
+        prev = self.offset
+
+        if self.keys is None or (prev + num_steps) > self.keys[0].shape[2]:
+            el_per_int = 8 * mx.uint32.size // self.bits
+            n_steps = (self.step + num_steps - 1) // self.step
+
+            k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
+            scales_dim = k_head_dim // self.group_size
+            k_scale_shape = k_shape[:-1] + (scales_dim,)
+            v_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
+
+            scale_bias_init = lambda: mx.zeros(k_scale_shape, keys.dtype)
+            new_k = (mx.zeros(k_shape, mx.uint32), scale_bias_init(), scale_bias_init())
+            new_v = (mx.zeros(v_shape, mx.uint32), scale_bias_init(), scale_bias_init())
+
+            if self.keys is not None:
+                if prev % self.step != 0:
+                    self.keys = tuple(x[..., :prev, :] for x in self.keys)
+                    self.values = tuple(x[..., :prev, :] for x in self.values)
+                self.keys = tuple(
+                    mx.concatenate([self.keys[i], new_k[i]], axis=2) for i in range(3)
+                )
+                self.values = tuple(
+                    mx.concatenate([self.values[i], new_v[i]], axis=2) for i in range(3)
+                )
+            else:
+                self.keys, self.values = new_k, new_v
+
+        self.offset += num_steps
+
+        if num_steps > 1:
+            keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
+            values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
+            for i in range(len(self.keys)):
+                self.keys[i][..., prev : self.offset, :] = keys[i]
+                self.values[i][..., prev : self.offset, :] = values[i]
+
+        else:
+            outputs = mx.fast.quantized_kv_update(
+                keys,
+                values,
+                *self.keys,
+                *self.values,
+                prev,
+                group_size=self.group_size,
+                bits=self.bits
+            )
+            self.keys = outputs[:3]
+            self.values = outputs[3:]
+
+        return (
+            tuple(x[..., : self.offset, :] for x in self.keys),
+            tuple(x[..., : self.offset, :] for x in self.values),
+        )
+
+    def state(self):
+        return self.keys, self.values
+
+    @classmethod
+    def from_cache(
+        cls, cache: _BaseCache, group_size: int = 64, bits: int = 4
+    ) -> "QuantizedKVCache":
+        quant_cache = cls(group_size=group_size, bits=bits)
+        quant_cache.offset = cache.offset
+        quant_cache.keys = mx.quantize(cache.keys, group_size=group_size, bits=bits)
+        quant_cache.values = mx.quantize(cache.values, group_size=group_size, bits=bits)
+        return quant_cache
+
+
 class KVCache(_BaseCache):
     def __init__(self):
         self.keys = None
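Sanity check on the packing arithmetic (standalone sketch, not part of the commit): `mx.quantize` returns packed values plus per-group scales and biases, which is exactly where `el_per_int` and `scales_dim` above come from. With `bits=4`, eight 4-bit elements fit in each uint32 word.

```python
import mlx.core as mx

bits, group_size = 4, 64
el_per_int = 8 * mx.uint32.size // bits  # 32 bits per word / 4 bits each = 8

x = mx.random.normal((16, 64))  # e.g. 16 cache positions, head_dim = 64
wq, scales, biases = mx.quantize(x, group_size=group_size, bits=bits)

print(wq.shape)      # (16, 8): 64 values packed 8 per uint32
print(scales.shape)  # (16, 1): one scale per 64-element group
print(biases.shape)  # (16, 1): one bias per group

x_hat = mx.dequantize(wq, scales, biases, group_size=group_size, bits=bits)
print(mx.abs(x - x_hat).max())  # small reconstruction error
```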
@@ -180,6 +275,9 @@ class KVCache(_BaseCache):
         self.offset -= n
         return n

+    def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
+        return QuantizedKVCache.from_cache(self, group_size=group_size, bits=bits)
+

 class RotatingKVCache(_BaseCache):
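Conversion sketch (not part of the commit; the model call is a placeholder): `to_quantized` makes it possible to prefill with a plain `KVCache` and switch to the quantized representation before decoding. `from_cache` carries `offset` over, so subsequent reads still trim to the valid region.

```python
# Prefill with a regular cache first...
cache = make_prompt_cache(model)
# logits = model(prompt_tokens, cache=cache)

# ...then quantize each layer's K/V for the decode loop.
cache = [c.to_quantized(group_size=64, bits=4) for c in cache]
```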