mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
add QuantizedKVCache
This commit is contained in:
parent
9f34fdbda4
commit
48655a7f83
@ -107,6 +107,23 @@ def setup_arg_parser():
|
||||
default=None,
|
||||
help="A file containing saved KV caches to avoid recomputing them",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantized-kv",
|
||||
help="Whether to quantize the KV cache.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-group-size",
|
||||
type=int,
|
||||
help="Group size for kv cache quantization.",
|
||||
default=64,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-bits",
|
||||
type=int,
|
||||
help="Number of bits for kv cache quantization.",
|
||||
default=8,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@ -150,7 +167,11 @@ def main():
|
||||
using_cache = args.prompt_cache_file is not None
|
||||
if using_cache:
|
||||
prompt_cache, metadata = load_prompt_cache(
|
||||
args.prompt_cache_file, return_metadata=True
|
||||
args.prompt_cache_file,
|
||||
return_metadata=True,
|
||||
quantized_kv=args.quantized_kv,
|
||||
kv_group_size=args.kv_group_size,
|
||||
kv_bits=args.kv_bits,
|
||||
)
|
||||
|
||||
# Building tokenizer_config
|
||||
@ -227,6 +248,9 @@ def main():
|
||||
top_p=args.top_p,
|
||||
max_kv_size=args.max_kv_size,
|
||||
prompt_cache=prompt_cache if using_cache else None,
|
||||
quantized_kv=args.quantized_kv,
|
||||
kv_group_size=args.kv_group_size,
|
||||
kv_bits=args.kv_bits,
|
||||
)
|
||||
if not args.verbose:
|
||||
print(response)
|
||||
|
@ -7,7 +7,13 @@ import mlx.nn as nn
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
|
||||
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
|
||||
def make_prompt_cache(
|
||||
model: nn.Module,
|
||||
max_kv_size: Optional[int] = None,
|
||||
quantized_kv: bool = False,
|
||||
kv_group_size: int = 64,
|
||||
kv_bits: int = 8,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Construct the model's cache for use when cgeneration.
|
||||
|
||||
@ -24,7 +30,12 @@ def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> Li
|
||||
return model.make_cache()
|
||||
|
||||
num_layers = len(model.layers)
|
||||
if max_kv_size is not None:
|
||||
if quantized_kv:
|
||||
return [
|
||||
QuantizedKVCache(group_size=kv_group_size, bits=kv_bits)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
elif max_kv_size is not None:
|
||||
return [
|
||||
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
|
||||
]
|
||||
@ -51,7 +62,9 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str]
|
||||
mx.save_safetensors(file_name, cache_data, cache_metadata)
|
||||
|
||||
|
||||
def load_prompt_cache(file_name, return_metadata=False):
|
||||
def load_prompt_cache(
|
||||
file_name, return_metadata=False, quantized_kv=False, kv_group_size=64, kv_bits=8
|
||||
):
|
||||
"""
|
||||
Load a prompt cache from a file.
|
||||
|
||||
@ -72,6 +85,8 @@ def load_prompt_cache(file_name, return_metadata=False):
|
||||
for c, state, meta_state in zip(cache, arrays, info):
|
||||
c.state = state
|
||||
c.meta_state = meta_state
|
||||
if quantized_kv:
|
||||
cache = [c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in cache]
|
||||
if return_metadata:
|
||||
return cache, metadata
|
||||
return cache
|
||||
@ -126,6 +141,86 @@ class _BaseCache:
|
||||
return False
|
||||
|
||||
|
||||
class QuantizedKVCache(_BaseCache):
|
||||
def __init__(self, group_size: int = 64, bits: int = 4):
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
self.group_size = group_size
|
||||
self.bits = bits
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
B, n_kv_heads, num_steps, k_head_dim = keys.shape
|
||||
prev = self.offset
|
||||
|
||||
if self.keys is None or (prev + num_steps) > self.keys[0].shape[2]:
|
||||
el_per_int = 8 * mx.uint32.size // self.bits
|
||||
n_steps = (self.step + keys[0].shape[2] - 1) // self.step
|
||||
|
||||
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
|
||||
scales_dim = k_head_dim // self.group_size
|
||||
k_scale_shape = k_shape[:-1] + (scales_dim,)
|
||||
v_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
|
||||
|
||||
scale_bias_init = lambda: mx.zeros(k_scale_shape, keys.dtype)
|
||||
new_k = (mx.zeros(k_shape, mx.uint32), scale_bias_init(), scale_bias_init())
|
||||
new_v = (mx.zeros(v_shape, mx.uint32), scale_bias_init(), scale_bias_init())
|
||||
|
||||
if self.keys is not None:
|
||||
if prev % self.step != 0:
|
||||
self.keys = tuple(x[..., :prev, :] for x in self.keys)
|
||||
self.values = tuple(x[..., :prev, :] for x in self.values)
|
||||
self.keys = tuple(
|
||||
mx.concatenate([self.keys[i], new_k[i]], axis=2) for i in range(3)
|
||||
)
|
||||
self.values = tuple(
|
||||
mx.concatenate([self.values[i], new_v[i]], axis=2) for i in range(3)
|
||||
)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
|
||||
self.offset += num_steps
|
||||
|
||||
if num_steps > 1:
|
||||
keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
|
||||
values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
|
||||
for i in range(len(self.keys)):
|
||||
self.keys[i][..., prev : self.offset, :] = keys[i]
|
||||
self.values[i][..., prev : self.offset, :] = values[i]
|
||||
|
||||
else:
|
||||
outputs = mx.fast.quantized_kv_update(
|
||||
keys,
|
||||
values,
|
||||
*self.keys,
|
||||
*self.values,
|
||||
prev,
|
||||
group_size=self.group_size,
|
||||
bits=self.bits
|
||||
)
|
||||
self.keys = outputs[:3]
|
||||
self.values = outputs[3:]
|
||||
|
||||
return (
|
||||
tuple(x[..., : self.offset, :] for x in self.keys),
|
||||
tuple(x[..., : self.offset, :] for x in self.values),
|
||||
)
|
||||
|
||||
def state(self):
|
||||
return self.keys, self.values
|
||||
|
||||
@classmethod
|
||||
def from_cache(
|
||||
cls, cache: _BaseCache, group_size: int = 64, bits: int = 4
|
||||
) -> "QuantizedKVCache":
|
||||
quant_cache = cls(group_size=group_size, bits=bits)
|
||||
quant_cache.offset = cache.offset
|
||||
quant_cache.keys = mx.quantize(cache.keys, group_size=group_size, bits=bits)
|
||||
quant_cache.values = mx.quantize(cache.values, group_size=group_size, bits=bits)
|
||||
return quant_cache
|
||||
|
||||
|
||||
class KVCache(_BaseCache):
|
||||
def __init__(self):
|
||||
self.keys = None
|
||||
@ -180,6 +275,9 @@ class KVCache(_BaseCache):
|
||||
self.offset -= n
|
||||
return n
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||
return QuantizedKVCache.from_cache(self, group_size=group_size, bits=bits)
|
||||
|
||||
|
||||
class RotatingKVCache(_BaseCache):
|
||||
|
||||
|
@ -1,12 +1,13 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .cache import QuantizedKVCache
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -190,9 +191,21 @@ class Attention(nn.Module):
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
if isinstance(cache, QuantizedKVCache):
|
||||
output = mx.fast.quantized_scaled_dot_product_attention(
|
||||
queries,
|
||||
*keys,
|
||||
*values,
|
||||
scale=self.scale,
|
||||
mask=mask,
|
||||
group_size=cache.group_size,
|
||||
bits=cache.bits,
|
||||
)
|
||||
else:
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
@ -19,7 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
from .models import base, cache
|
||||
from .models import cache
|
||||
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
|
||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||
from .tuner.utils import dequantize as dequantize_model
|
||||
@ -173,6 +173,9 @@ def generate_step(
|
||||
prompt_cache: Optional[Any] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
quantized_kv: bool = False,
|
||||
kv_group_size: int = 64,
|
||||
kv_bits: int = 8,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing token ids based on the given prompt from the model.
|
||||
@ -255,7 +258,13 @@ def generate_step(
|
||||
|
||||
# Create the KV cache for generation
|
||||
if prompt_cache is None:
|
||||
prompt_cache = cache.make_prompt_cache(model, max_kv_size)
|
||||
prompt_cache = cache.make_prompt_cache(
|
||||
model,
|
||||
max_kv_size=max_kv_size,
|
||||
quantized_kv=quantized_kv,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
elif len(prompt_cache) != len(model.layers):
|
||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user