mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Quantized KV Cache (#1075)
* add QuantizedKVCache * simplify * add tests * single sdpa function * fix sed * in place * fix tests * support different k and v head dims
This commit is contained in:
@@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
|
||||
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
|
||||
def make_prompt_cache(
|
||||
model: nn.Module,
|
||||
max_kv_size: Optional[int] = None,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Construct the model's cache for use when cgeneration.
|
||||
|
||||
@@ -126,6 +129,88 @@ class _BaseCache:
|
||||
return False
|
||||
|
||||
|
||||
class QuantizedKVCache(_BaseCache):
|
||||
def __init__(self, group_size: int = 64, bits: int = 8):
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
self.group_size = group_size
|
||||
self.bits = bits
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
B, n_kv_heads, num_steps, k_head_dim = keys.shape
|
||||
v_head_dim = values.shape[-1]
|
||||
prev = self.offset
|
||||
|
||||
if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
|
||||
el_per_int = 8 * mx.uint32.size // self.bits
|
||||
new_steps = (self.step + num_steps - 1) // self.step * self.step
|
||||
shape = (B, n_kv_heads, new_steps)
|
||||
|
||||
def init_quant(dim):
|
||||
return (
|
||||
mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
|
||||
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
|
||||
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
|
||||
)
|
||||
|
||||
def expand_quant(x):
|
||||
new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
|
||||
return mx.concatenate([x, new_x], axis=-2)
|
||||
|
||||
if self.keys is not None:
|
||||
if prev % self.step != 0:
|
||||
self.keys, self.values = tree_map(
|
||||
lambda x: x[..., :prev, :], (self.keys, self.values)
|
||||
)
|
||||
|
||||
self.keys, self.values = tree_map(
|
||||
expand_quant, (self.keys, self.values)
|
||||
)
|
||||
else:
|
||||
self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
|
||||
|
||||
self.offset += num_steps
|
||||
|
||||
keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
|
||||
values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
|
||||
for i in range(len(self.keys)):
|
||||
self.keys[i][..., prev : self.offset, :] = keys[i]
|
||||
self.values[i][..., prev : self.offset, :] = values[i]
|
||||
|
||||
return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if self.offset == self.keys[0].shape[2]:
|
||||
return self.keys, self.values
|
||||
else:
|
||||
return tree_map(
|
||||
lambda x: x[..., : self.offset, :], (self.keys, self.values)
|
||||
)
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.keys, self.values = v
|
||||
|
||||
@property
|
||||
def meta_state(self):
|
||||
return tuple(map(str, (self.step, self.offset, self.group_size, self.bits)))
|
||||
|
||||
@meta_state.setter
|
||||
def meta_state(self, v):
|
||||
self.step, self.offset, self.group_size, self.bits = map(int, v)
|
||||
|
||||
def is_trimmable(self):
|
||||
return True
|
||||
|
||||
def trim(self, n):
|
||||
n = min(self.offset, n)
|
||||
self.offset -= n
|
||||
return n
|
||||
|
||||
|
||||
class KVCache(_BaseCache):
|
||||
def __init__(self):
|
||||
self.keys = None
|
||||
@@ -180,6 +265,16 @@ class KVCache(_BaseCache):
|
||||
self.offset -= n
|
||||
return n
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
|
||||
quant_cache.offset = self.offset
|
||||
if self.keys is not None:
|
||||
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
|
||||
quant_cache.values = mx.quantize(
|
||||
self.values, group_size=group_size, bits=bits
|
||||
)
|
||||
return quant_cache
|
||||
|
||||
|
||||
class RotatingKVCache(_BaseCache):
|
||||
|
||||
@@ -320,6 +415,9 @@ class RotatingKVCache(_BaseCache):
|
||||
self._idx -= n
|
||||
return n
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||
raise NotImplementedError("RotatingKVCache Quantization NYI")
|
||||
|
||||
|
||||
class MambaCache(_BaseCache):
|
||||
def __init__(self):
|
||||
|
Reference in New Issue
Block a user