add QuantizedKVCache

Author: Alex Barron
Date: 2024-10-26 00:23:46 -07:00
parent 9f34fdbda4
commit 48655a7f83
4 changed files with 154 additions and 10 deletions


@@ -19,7 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 # Local imports
-from .models import base, cache
+from .models import cache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
@@ -173,6 +173,9 @@ def generate_step(
     prompt_cache: Optional[Any] = None,
     logit_bias: Optional[Dict[int, float]] = None,
     logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    quantized_kv: bool = False,
+    kv_group_size: int = 64,
+    kv_bits: int = 8,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
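The three new keyword arguments let callers opt into a quantized KV cache at generation time. Below is a minimal usage sketch, not part of this commit: it assumes generate_step is importable from mlx_lm.utils, that a model and tokenizer are loaded with mlx_lm.load, and the model path is purely illustrative.

import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step

model, tokenizer = load("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")  # illustrative path
prompt = mx.array(tokenizer.encode("Write a haiku about autumn."))

# Ask for an 8-bit KV cache with quantization groups of 64 elements.
max_tokens = 64
for n, (token, _logprobs) in enumerate(
    generate_step(prompt, model, quantized_kv=True, kv_group_size=64, kv_bits=8)
):
    if n >= max_tokens or token.item() == tokenizer.eos_token_id:
        break
    print(tokenizer.decode([token.item()]), end="", flush=True)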
@@ -255,7 +258,13 @@ def generate_step(
     # Create the KV cache for generation
     if prompt_cache is None:
-        prompt_cache = cache.make_prompt_cache(model, max_kv_size)
+        prompt_cache = cache.make_prompt_cache(
+            model,
+            max_kv_size=max_kv_size,
+            quantized_kv=quantized_kv,
+            kv_group_size=kv_group_size,
+            kv_bits=kv_bits,
+        )
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")
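make_prompt_cache now receives the quantization settings, so each layer's cache can be created quantized rather than as a regular KV cache. The cache class itself is not shown in this hunk; the following is only a rough sketch of the idea, not the actual QuantizedKVCache, and it assumes mx.quantize / mx.dequantize accept batched (batch, heads, seq, head_dim) arrays whose last axis is divisible by the group size. The class name and structure are invented for illustration.

import mlx.core as mx

class NaiveQuantizedKVCache:
    # Illustrative only: the real QuantizedKVCache is presumably more
    # efficient (e.g. appending quantized blocks instead of re-quantizing
    # the whole history on every step).
    def __init__(self, group_size: int = 64, bits: int = 8):
        self.group_size = group_size
        self.bits = bits
        self.keys = None    # tuple (packed, scales, biases) or None
        self.values = None
        self.offset = 0

    def update_and_fetch(self, keys: mx.array, values: mx.array):
        # Dequantize the cached history, append the new tokens along the
        # sequence axis, and re-quantize the result.
        if self.keys is not None:
            old_k = mx.dequantize(*self.keys, group_size=self.group_size, bits=self.bits)
            old_v = mx.dequantize(*self.values, group_size=self.group_size, bits=self.bits)
            keys = mx.concatenate([old_k, keys], axis=2)
            values = mx.concatenate([old_v, values], axis=2)
        self.keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
        self.values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
        self.offset = keys.shape[2]
        return keys, values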