From 48655a7f831869821c6ea7bb341d9bfc1ad86cbc Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Sat, 26 Oct 2024 00:23:46 -0700
Subject: [PATCH] add QuantizedKVCache

---
 llms/mlx_lm/generate.py     |  26 ++++++++-
 llms/mlx_lm/models/cache.py | 104 ++++++++++++++++++++++++++++++++++--
 llms/mlx_lm/models/llama.py |  21 ++++++--
 llms/mlx_lm/utils.py        |  13 ++++-
 4 files changed, 154 insertions(+), 10 deletions(-)

diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 0bf98ab2..b099552a 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -107,6 +107,23 @@ def setup_arg_parser():
         default=None,
         help="A file containing saved KV caches to avoid recomputing them",
     )
+    parser.add_argument(
+        "--quantized-kv",
+        help="Whether to quantize the KV cache.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--kv-group-size",
+        type=int,
+        help="Group size for KV cache quantization.",
+        default=64,
+    )
+    parser.add_argument(
+        "--kv-bits",
+        type=int,
+        help="Number of bits for KV cache quantization.",
+        default=8,
+    )
     return parser
 
 
@@ -150,7 +167,11 @@ def main():
     using_cache = args.prompt_cache_file is not None
     if using_cache:
         prompt_cache, metadata = load_prompt_cache(
-            args.prompt_cache_file, return_metadata=True
+            args.prompt_cache_file,
+            return_metadata=True,
+            quantized_kv=args.quantized_kv,
+            kv_group_size=args.kv_group_size,
+            kv_bits=args.kv_bits,
         )
 
     # Building tokenizer_config
@@ -227,6 +248,9 @@ def main():
         top_p=args.top_p,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
+        quantized_kv=args.quantized_kv,
+        kv_group_size=args.kv_group_size,
+        kv_bits=args.kv_bits,
     )
     if not args.verbose:
         print(response)
diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index a6a56e0a..bd3d4932 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -7,7 +7,13 @@ import mlx.nn as nn
 from mlx.utils import tree_flatten, tree_unflatten
 
 
-def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
+def make_prompt_cache(
+    model: nn.Module,
+    max_kv_size: Optional[int] = None,
+    quantized_kv: bool = False,
+    kv_group_size: int = 64,
+    kv_bits: int = 8,
+) -> List[Any]:
     """
     Construct the model's cache for use when generating.
 
@@ -24,7 +30,12 @@ def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> Li
         return model.make_cache()
 
     num_layers = len(model.layers)
-    if max_kv_size is not None:
+    if quantized_kv:
+        return [
+            QuantizedKVCache(group_size=kv_group_size, bits=kv_bits)
+            for _ in range(num_layers)
+        ]
+    elif max_kv_size is not None:
         return [
             RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
         ]
@@ -51,7 +62,9 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str]
     mx.save_safetensors(file_name, cache_data, cache_metadata)
 
 
-def load_prompt_cache(file_name, return_metadata=False):
+def load_prompt_cache(
+    file_name, return_metadata=False, quantized_kv=False, kv_group_size=64, kv_bits=8
+):
     """
     Load a prompt cache from a file.
@@ -72,6 +85,8 @@ def load_prompt_cache(file_name, return_metadata=False):
     for c, state, meta_state in zip(cache, arrays, info):
         c.state = state
         c.meta_state = meta_state
+    if quantized_kv:
+        cache = [c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in cache]
     if return_metadata:
         return cache, metadata
     return cache
@@ -126,6 +141,86 @@ class _BaseCache:
         return False
 
 
+class QuantizedKVCache(_BaseCache):
+    def __init__(self, group_size: int = 64, bits: int = 4):
+        self.keys = None
+        self.values = None
+        self.offset = 0
+        self.step = 256
+        self.group_size = group_size
+        self.bits = bits
+
+    def update_and_fetch(self, keys, values):
+        B, n_kv_heads, num_steps, k_head_dim = keys.shape
+        v_head_dim = values.shape[-1]
+        prev = self.offset
+
+        if self.keys is None or (prev + num_steps) > self.keys[0].shape[2]:
+            el_per_int = 8 * mx.uint32.size // self.bits
+            n_steps = (self.step + num_steps - 1) // self.step
+
+            k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
+            v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim // el_per_int)
+            k_scale_shape = k_shape[:-1] + (k_head_dim // self.group_size,)
+            v_scale_shape = v_shape[:-1] + (v_head_dim // self.group_size,)
+
+            scale_bias_init = lambda shape: mx.zeros(shape, keys.dtype)
+            new_k = (
+                mx.zeros(k_shape, mx.uint32),
+                scale_bias_init(k_scale_shape),
+                scale_bias_init(k_scale_shape),
+            )
+            new_v = (
+                mx.zeros(v_shape, mx.uint32),
+                scale_bias_init(v_scale_shape),
+                scale_bias_init(v_scale_shape),
+            )
+
+            if self.keys is not None:
+                if prev % self.step != 0:
+                    self.keys = tuple(x[..., :prev, :] for x in self.keys)
+                    self.values = tuple(x[..., :prev, :] for x in self.values)
+                self.keys = tuple(
+                    mx.concatenate([self.keys[i], new_k[i]], axis=2) for i in range(3)
+                )
+                self.values = tuple(
+                    mx.concatenate([self.values[i], new_v[i]], axis=2) for i in range(3)
+                )
+            else:
+                self.keys, self.values = new_k, new_v
+
+        self.offset += num_steps
+
+        if num_steps > 1:
+            keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
+            values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
+            for i in range(len(self.keys)):
+                self.keys[i][..., prev : self.offset, :] = keys[i]
+                self.values[i][..., prev : self.offset, :] = values[i]
+        else:
+            outputs = mx.fast.quantized_kv_update(
+                keys,
+                values,
+                *self.keys,
+                *self.values,
+                prev,
+                group_size=self.group_size,
+                bits=self.bits,
+            )
+            self.keys = outputs[:3]
+            self.values = outputs[3:]
+
+        return (
+            tuple(x[..., : self.offset, :] for x in self.keys),
+            tuple(x[..., : self.offset, :] for x in self.values),
+        )
+
+    @property
+    def state(self):
+        return self.keys, self.values
+
+    @classmethod
+    def from_cache(
+        cls, cache: _BaseCache, group_size: int = 64, bits: int = 4
+    ) -> "QuantizedKVCache":
+        quant_cache = cls(group_size=group_size, bits=bits)
+        quant_cache.offset = cache.offset
+        if cache.keys is not None:
+            quant_cache.keys = mx.quantize(cache.keys, group_size=group_size, bits=bits)
+            quant_cache.values = mx.quantize(
+                cache.values, group_size=group_size, bits=bits
+            )
+        return quant_cache
+
+
 class KVCache(_BaseCache):
     def __init__(self):
         self.keys = None
@@ -180,6 +275,9 @@ class KVCache(_BaseCache):
             self.offset -= n
             return n
 
+    def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
+        return QuantizedKVCache.from_cache(self, group_size=group_size, bits=bits)
+
 
 class RotatingKVCache(_BaseCache):
 
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 7da6b333..ffa52c8b 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -1,12 +1,13 @@
 # Copyright © 2023-2024 Apple Inc.
 
 from dataclasses import dataclass
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs, create_attention_mask
+from .cache import QuantizedKVCache
 
 
 @dataclass
@@ -190,9 +191,21 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
 
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
-        )
+        if isinstance(cache, QuantizedKVCache):
+            output = mx.fast.quantized_scaled_dot_product_attention(
+                queries,
+                *keys,
+                *values,
+                scale=self.scale,
+                mask=mask,
+                group_size=cache.group_size,
+                bits=cache.bits,
+            )
+        else:
+            output = mx.fast.scaled_dot_product_attention(
+                queries, keys, values, scale=self.scale, mask=mask
+            )
+
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
 
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 5b437c98..8e87d218 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -19,7 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
-from .models import base, cache
+from .models import cache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
@@ -173,6 +173,9 @@ def generate_step(
     prompt_cache: Optional[Any] = None,
     logit_bias: Optional[Dict[int, float]] = None,
     logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    quantized_kv: bool = False,
+    kv_group_size: int = 64,
+    kv_bits: int = 8,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -255,7 +258,13 @@ def generate_step(
 
     # Create the KV cache for generation
     if prompt_cache is None:
-        prompt_cache = cache.make_prompt_cache(model, max_kv_size)
+        prompt_cache = cache.make_prompt_cache(
+            model,
+            max_kv_size=max_kv_size,
+            quantized_kv=quantized_kv,
+            kv_group_size=kv_group_size,
+            kv_bits=kv_bits,
+        )
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")
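
Usage sketch (not part of the diff): a minimal example of how the options
added above might be exercised once the patch is applied. The model repo
string and the prompt are placeholders, and passing prompt_cache and
max_tokens through generate() relies on generate() forwarding extra keyword
arguments to generate_step():

    from mlx_lm import load, generate
    from mlx_lm.models.cache import make_prompt_cache

    # Placeholder model path; any mlx-lm compatible model should work.
    model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")

    # Build a quantized KV cache up front (8 bits with group size 64 are
    # the defaults introduced by this patch).
    prompt_cache = make_prompt_cache(
        model, quantized_kv=True, kv_group_size=64, kv_bits=8
    )

    # Generate against the quantized cache instead of the default KVCache.
    text = generate(
        model,
        tokenizer,
        prompt="Explain KV cache quantization in one sentence.",
        max_tokens=64,
        prompt_cache=prompt_cache,
    )
    print(text)

    # Equivalent CLI run using the new flags:
    #   python -m mlx_lm.generate --model <path> --prompt "..." \
    #       --quantized-kv --kv-bits 8 --kv-group-size 64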