Mirror of https://github.com/ml-explore/mlx-examples.git
Quantized KV Cache (#1075)
* add QuantizedKVCache
* simplify
* add tests
* single sdpa function
* fix sed
* in place
* fix tests
* support different k and v head dims
@@ -19,7 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
-from .models import base, cache
+from .models import cache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
@@ -159,6 +159,18 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
     return logits
 
 
+def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
+    if (
+        kv_bits is not None
+        and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
+        and prompt_cache[0].offset > quantized_kv_start
+    ):
+        for i in range(len(prompt_cache)):
+            prompt_cache[i] = prompt_cache[i].to_quantized(
+                group_size=kv_group_size, bits=kv_bits
+            )
+
+
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
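The helper converts every layer's cache in place, and only once: after the first conversion `prompt_cache[0]` is a `QuantizedKVCache`, so the `isinstance` guard short-circuits on all later steps. A minimal, self-contained sketch of this guard-and-convert pattern, using hypothetical stand-in classes rather than mlx_lm's real cache types:

# Self-contained sketch of the guard-and-convert pattern above.
# FakeCache and FakeQuantizedCache are hypothetical stand-ins,
# not mlx_lm types.

class FakeQuantizedCache:
    def __init__(self, offset, group_size, bits):
        self.offset, self.group_size, self.bits = offset, group_size, bits


class FakeCache:
    def __init__(self, offset):
        self.offset = offset  # number of tokens currently in the cache

    def to_quantized(self, group_size, bits):
        return FakeQuantizedCache(self.offset, group_size, bits)


def maybe_quantize(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
    # Convert every layer's cache exactly once, after the offset threshold.
    if (
        kv_bits is not None
        and not isinstance(prompt_cache[0], FakeQuantizedCache)
        and prompt_cache[0].offset > quantized_kv_start
    ):
        for i in range(len(prompt_cache)):
            prompt_cache[i] = prompt_cache[i].to_quantized(
                group_size=kv_group_size, bits=kv_bits
            )


caches = [FakeCache(offset=128) for _ in range(4)]  # one cache per layer
maybe_quantize(caches, quantized_kv_start=64, kv_group_size=64, kv_bits=4)
assert all(isinstance(c, FakeQuantizedCache) for c in caches)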
@@ -173,6 +185,9 @@ def generate_step(
     prompt_cache: Optional[Any] = None,
     logit_bias: Optional[Dict[int, float]] = None,
     logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    kv_bits: Optional[int] = None,
+    kv_group_size: int = 64,
+    quantized_kv_start: int = 0,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
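The three new keyword arguments let callers opt into KV cache quantization per call. A hedged usage sketch, assuming `model` is a loaded mlx `nn.Module` and `prompt` an `mx.array` of token ids; the keyword names come from the signature above:

# Hedged sketch: quantize the KV cache to 4 bits once more than 512
# tokens have been cached, keeping earlier steps in full precision.
for y, logprobs in generate_step(
    prompt,
    model,
    kv_bits=4,               # None (the default) disables quantization
    kv_group_size=64,        # quantization group size
    quantized_kv_start=512,  # offset at which quantization kicks in
):
    ...

Note that the bare generator does not stop on its own; mlx_lm's higher-level generate helpers bound this loop with a maximum token count.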
@@ -201,6 +216,11 @@ def generate_step(
         logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
             A list of functions that take tokens and logits and return the processed
             logits. Default: ``None``.
+        kv_bits (int, optional): Number of bits to use for KV cache quantization.
+            None implies no cache quantization. Default: ``None``.
+        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
+        quantized_kv_start (int): Step to begin using a quantized KV cache
+            when ``kv_bits`` is non-None. Default: ``0``.
 
     Yields:
         Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@@ -255,11 +275,15 @@ def generate_step(
 
     # Create the KV cache for generation
     if prompt_cache is None:
-        prompt_cache = cache.make_prompt_cache(model, max_kv_size)
+        prompt_cache = cache.make_prompt_cache(
+            model,
+            max_kv_size=max_kv_size,
+        )
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")
 
     def _step(y):
+
         logits = model(y[None], cache=prompt_cache)
         logits = logits[:, -1, :]
 
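Cache creation is unchanged in behavior; `max_kv_size` is now just passed by keyword. A hedged sketch of building the cache yourself and handing it to `generate_step` via `prompt_cache`, which lets one cache be reused across calls (assumes a loaded `model`; the `max_kv_size=None` default for an unbounded cache is an assumption):

from mlx_lm.models import cache

# Build the per-layer cache once and pass it to generate_step;
# the length check above requires one cache entry per model layer.
prompt_cache = cache.make_prompt_cache(model, max_kv_size=None)
assert len(prompt_cache) == len(model.layers)

for y, logprobs in generate_step(prompt, model, prompt_cache=prompt_cache):
    ...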
@@ -270,6 +294,10 @@ def generate_step(
             for processor in logits_processor:
                 logits = processor(tokens, logits)
 
+        maybe_quantize_kv_cache(
+            prompt_cache, quantized_kv_start, kv_group_size, kv_bits
+        )
+
         y, logprobs = sample(logits)
         return y, logprobs.squeeze(0)
 
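Because `maybe_quantize_kv_cache` runs inside `_step`, the conversion check costs one `isinstance` per generated token once the cache is quantized. A rough, hedged estimate of the resulting memory saving, assuming mlx-style affine quantization with one fp16 scale and one fp16 bias per group (an assumption, not something stated in the diff):

# Back-of-the-envelope bits per cached element at kv_bits=4, group_size=64.
bits, group_size = 4, 64
overhead = 2 * 16 / group_size        # fp16 scale + bias amortized per group
effective_bits = bits + overhead      # 4.5 bits per element
print(f"~{16 / effective_bits:.1f}x smaller than fp16")  # ~3.6x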