mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
add QuantizedKVCache
This commit is contained in:
parent
9f34fdbda4
commit
48655a7f83
@ -107,6 +107,23 @@ def setup_arg_parser():
|
|||||||
default=None,
|
default=None,
|
||||||
help="A file containing saved KV caches to avoid recomputing them",
|
help="A file containing saved KV caches to avoid recomputing them",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantized-kv",
|
||||||
|
help="Whether to quantize the KV cache.",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-group-size",
|
||||||
|
type=int,
|
||||||
|
help="Group size for kv cache quantization.",
|
||||||
|
default=64,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-bits",
|
||||||
|
type=int,
|
||||||
|
help="Number of bits for kv cache quantization.",
|
||||||
|
default=8,
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -150,7 +167,11 @@ def main():
|
|||||||
using_cache = args.prompt_cache_file is not None
|
using_cache = args.prompt_cache_file is not None
|
||||||
if using_cache:
|
if using_cache:
|
||||||
prompt_cache, metadata = load_prompt_cache(
|
prompt_cache, metadata = load_prompt_cache(
|
||||||
args.prompt_cache_file, return_metadata=True
|
args.prompt_cache_file,
|
||||||
|
return_metadata=True,
|
||||||
|
quantized_kv=args.quantized_kv,
|
||||||
|
kv_group_size=args.kv_group_size,
|
||||||
|
kv_bits=args.kv_bits,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Building tokenizer_config
|
# Building tokenizer_config
|
||||||
@ -227,6 +248,9 @@ def main():
|
|||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
max_kv_size=args.max_kv_size,
|
max_kv_size=args.max_kv_size,
|
||||||
prompt_cache=prompt_cache if using_cache else None,
|
prompt_cache=prompt_cache if using_cache else None,
|
||||||
|
quantized_kv=args.quantized_kv,
|
||||||
|
kv_group_size=args.kv_group_size,
|
||||||
|
kv_bits=args.kv_bits,
|
||||||
)
|
)
|
||||||
if not args.verbose:
|
if not args.verbose:
|
||||||
print(response)
|
print(response)
|
||||||
|
@ -7,7 +7,13 @@ import mlx.nn as nn
|
|||||||
from mlx.utils import tree_flatten, tree_unflatten
|
from mlx.utils import tree_flatten, tree_unflatten
|
||||||
|
|
||||||
|
|
||||||
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
|
def make_prompt_cache(
|
||||||
|
model: nn.Module,
|
||||||
|
max_kv_size: Optional[int] = None,
|
||||||
|
quantized_kv: bool = False,
|
||||||
|
kv_group_size: int = 64,
|
||||||
|
kv_bits: int = 8,
|
||||||
|
) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
Construct the model's cache for use when cgeneration.
|
Construct the model's cache for use when cgeneration.
|
||||||
|
|
||||||
@ -24,7 +30,12 @@ def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> Li
|
|||||||
return model.make_cache()
|
return model.make_cache()
|
||||||
|
|
||||||
num_layers = len(model.layers)
|
num_layers = len(model.layers)
|
||||||
if max_kv_size is not None:
|
if quantized_kv:
|
||||||
|
return [
|
||||||
|
QuantizedKVCache(group_size=kv_group_size, bits=kv_bits)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
elif max_kv_size is not None:
|
||||||
return [
|
return [
|
||||||
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
|
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
|
||||||
]
|
]
|
||||||
@ -51,7 +62,9 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str]
|
|||||||
mx.save_safetensors(file_name, cache_data, cache_metadata)
|
mx.save_safetensors(file_name, cache_data, cache_metadata)
|
||||||
|
|
||||||
|
|
||||||
def load_prompt_cache(file_name, return_metadata=False):
|
def load_prompt_cache(
|
||||||
|
file_name, return_metadata=False, quantized_kv=False, kv_group_size=64, kv_bits=8
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Load a prompt cache from a file.
|
Load a prompt cache from a file.
|
||||||
|
|
||||||
@ -72,6 +85,8 @@ def load_prompt_cache(file_name, return_metadata=False):
|
|||||||
for c, state, meta_state in zip(cache, arrays, info):
|
for c, state, meta_state in zip(cache, arrays, info):
|
||||||
c.state = state
|
c.state = state
|
||||||
c.meta_state = meta_state
|
c.meta_state = meta_state
|
||||||
|
if quantized_kv:
|
||||||
|
cache = [c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in cache]
|
||||||
if return_metadata:
|
if return_metadata:
|
||||||
return cache, metadata
|
return cache, metadata
|
||||||
return cache
|
return cache
|
||||||
@ -126,6 +141,86 @@ class _BaseCache:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizedKVCache(_BaseCache):
|
||||||
|
def __init__(self, group_size: int = 64, bits: int = 4):
|
||||||
|
self.keys = None
|
||||||
|
self.values = None
|
||||||
|
self.offset = 0
|
||||||
|
self.step = 256
|
||||||
|
self.group_size = group_size
|
||||||
|
self.bits = bits
|
||||||
|
|
||||||
|
def update_and_fetch(self, keys, values):
|
||||||
|
B, n_kv_heads, num_steps, k_head_dim = keys.shape
|
||||||
|
prev = self.offset
|
||||||
|
|
||||||
|
if self.keys is None or (prev + num_steps) > self.keys[0].shape[2]:
|
||||||
|
el_per_int = 8 * mx.uint32.size // self.bits
|
||||||
|
n_steps = (self.step + keys[0].shape[2] - 1) // self.step
|
||||||
|
|
||||||
|
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
|
||||||
|
scales_dim = k_head_dim // self.group_size
|
||||||
|
k_scale_shape = k_shape[:-1] + (scales_dim,)
|
||||||
|
v_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
|
||||||
|
|
||||||
|
scale_bias_init = lambda: mx.zeros(k_scale_shape, keys.dtype)
|
||||||
|
new_k = (mx.zeros(k_shape, mx.uint32), scale_bias_init(), scale_bias_init())
|
||||||
|
new_v = (mx.zeros(v_shape, mx.uint32), scale_bias_init(), scale_bias_init())
|
||||||
|
|
||||||
|
if self.keys is not None:
|
||||||
|
if prev % self.step != 0:
|
||||||
|
self.keys = tuple(x[..., :prev, :] for x in self.keys)
|
||||||
|
self.values = tuple(x[..., :prev, :] for x in self.values)
|
||||||
|
self.keys = tuple(
|
||||||
|
mx.concatenate([self.keys[i], new_k[i]], axis=2) for i in range(3)
|
||||||
|
)
|
||||||
|
self.values = tuple(
|
||||||
|
mx.concatenate([self.values[i], new_v[i]], axis=2) for i in range(3)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.keys, self.values = new_k, new_v
|
||||||
|
|
||||||
|
self.offset += num_steps
|
||||||
|
|
||||||
|
if num_steps > 1:
|
||||||
|
keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
|
||||||
|
values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
|
||||||
|
for i in range(len(self.keys)):
|
||||||
|
self.keys[i][..., prev : self.offset, :] = keys[i]
|
||||||
|
self.values[i][..., prev : self.offset, :] = values[i]
|
||||||
|
|
||||||
|
else:
|
||||||
|
outputs = mx.fast.quantized_kv_update(
|
||||||
|
keys,
|
||||||
|
values,
|
||||||
|
*self.keys,
|
||||||
|
*self.values,
|
||||||
|
prev,
|
||||||
|
group_size=self.group_size,
|
||||||
|
bits=self.bits
|
||||||
|
)
|
||||||
|
self.keys = outputs[:3]
|
||||||
|
self.values = outputs[3:]
|
||||||
|
|
||||||
|
return (
|
||||||
|
tuple(x[..., : self.offset, :] for x in self.keys),
|
||||||
|
tuple(x[..., : self.offset, :] for x in self.values),
|
||||||
|
)
|
||||||
|
|
||||||
|
def state(self):
|
||||||
|
return self.keys, self.values
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_cache(
|
||||||
|
cls, cache: _BaseCache, group_size: int = 64, bits: int = 4
|
||||||
|
) -> "QuantizedKVCache":
|
||||||
|
quant_cache = cls(group_size=group_size, bits=bits)
|
||||||
|
quant_cache.offset = cache.offset
|
||||||
|
quant_cache.keys = mx.quantize(cache.keys, group_size=group_size, bits=bits)
|
||||||
|
quant_cache.values = mx.quantize(cache.values, group_size=group_size, bits=bits)
|
||||||
|
return quant_cache
|
||||||
|
|
||||||
|
|
||||||
class KVCache(_BaseCache):
|
class KVCache(_BaseCache):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.keys = None
|
self.keys = None
|
||||||
@ -180,6 +275,9 @@ class KVCache(_BaseCache):
|
|||||||
self.offset -= n
|
self.offset -= n
|
||||||
return n
|
return n
|
||||||
|
|
||||||
|
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||||
|
return QuantizedKVCache.from_cache(self, group_size=group_size, bits=bits)
|
||||||
|
|
||||||
|
|
||||||
class RotatingKVCache(_BaseCache):
|
class RotatingKVCache(_BaseCache):
|
||||||
|
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask
|
||||||
|
from .cache import QuantizedKVCache
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -190,9 +191,21 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
if isinstance(cache, QuantizedKVCache):
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
output = mx.fast.quantized_scaled_dot_product_attention(
|
||||||
)
|
queries,
|
||||||
|
*keys,
|
||||||
|
*values,
|
||||||
|
scale=self.scale,
|
||||||
|
mask=mask,
|
||||||
|
group_size=cache.group_size,
|
||||||
|
bits=cache.bits,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output = mx.fast.scaled_dot_product_attention(
|
||||||
|
queries, keys, values, scale=self.scale, mask=mask
|
||||||
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce
|
|||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
# Local imports
|
# Local imports
|
||||||
from .models import base, cache
|
from .models import cache
|
||||||
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
|
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
|
||||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||||
from .tuner.utils import dequantize as dequantize_model
|
from .tuner.utils import dequantize as dequantize_model
|
||||||
@ -173,6 +173,9 @@ def generate_step(
|
|||||||
prompt_cache: Optional[Any] = None,
|
prompt_cache: Optional[Any] = None,
|
||||||
logit_bias: Optional[Dict[int, float]] = None,
|
logit_bias: Optional[Dict[int, float]] = None,
|
||||||
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||||
|
quantized_kv: bool = False,
|
||||||
|
kv_group_size: int = 64,
|
||||||
|
kv_bits: int = 8,
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing token ids based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
@ -255,7 +258,13 @@ def generate_step(
|
|||||||
|
|
||||||
# Create the KV cache for generation
|
# Create the KV cache for generation
|
||||||
if prompt_cache is None:
|
if prompt_cache is None:
|
||||||
prompt_cache = cache.make_prompt_cache(model, max_kv_size)
|
prompt_cache = cache.make_prompt_cache(
|
||||||
|
model,
|
||||||
|
max_kv_size=max_kv_size,
|
||||||
|
quantized_kv=quantized_kv,
|
||||||
|
kv_group_size=kv_group_size,
|
||||||
|
kv_bits=kv_bits,
|
||||||
|
)
|
||||||
elif len(prompt_cache) != len(model.layers):
|
elif len(prompt_cache) != len(model.layers):
|
||||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user