diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 04e75a3e..381033b3 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -8,7 +8,13 @@ import time import mlx.core as mx from .models.cache import make_prompt_cache, save_prompt_cache -from .utils import load +from .utils import ( + DEFAULT_KV_BITS, + DEFAULT_KV_GROUP_SIZE, + check_quantized_kv_args, + load, + maybe_quantize_kv_cache, +) def setup_arg_parser(): @@ -70,6 +76,24 @@ def setup_arg_parser(): required=True, help="Message to be processed by the model ('-' reads from stdin)", ) + parser.add_argument( + "--quantized-kv-start", + help="Use a quantized KV cache from this step onwards.", + type=int, + default=None, + ) + parser.add_argument( + "--kv-group-size", + type=int, + help="Group size for kv cache quantization.", + default=DEFAULT_KV_GROUP_SIZE, + ) + parser.add_argument( + "--kv-bits", + type=int, + help="Number of bits for kv cache quantization.", + default=DEFAULT_KV_BITS, + ) return parser @@ -93,6 +117,8 @@ def main(): args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt + check_quantized_kv_args(args.quantized_kv_start, args.kv_group_size, args.kv_bits) + if args.use_default_chat_template: if tokenizer.chat_template is None: tokenizer.chat_template = tokenizer.default_chat_template @@ -127,6 +153,7 @@ def main(): start = time.time() max_msg_len = 0 while y.size > 0: + model(y[:step_size][None], cache=cache) mx.eval([c.state for c in cache]) processed += min(y.size, step_size) @@ -136,6 +163,11 @@ def main(): msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)" max_msg_len = max(max_msg_len, len(msg)) print(msg + " " * (max_msg_len - len(msg)), end="", flush=True) + + cache = maybe_quantize_kv_cache( + cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits + ) + print() print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 
b099552a..9b84e7e6 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -6,8 +6,8 @@ import sys import mlx.core as mx -from .models.cache import load_prompt_cache -from .utils import generate, load +from .models.cache import QuantizedKVCache, load_prompt_cache +from .utils import check_quantized_kv_args, generate, load DEFAULT_PROMPT = "hello" DEFAULT_MAX_TOKENS = 100 @@ -108,20 +108,23 @@ def setup_arg_parser(): help="A file containing saved KV caches to avoid recomputing them", ) parser.add_argument( - "--quantized-kv", - help="Whether to quantize the KV cache.", - action="store_true", + "--quantized-kv-start", + help="Use a quantized KV cache from this step onwards.", + type=int, + default=None, ) parser.add_argument( "--kv-group-size", type=int, - help="Group size for kv cache quantization.", + help="Group size for kv cache quantization. " + "--quantized-kv-start must be provided to have an effect.", default=64, ) parser.add_argument( "--kv-bits", type=int, - help="Number of bits for kv cache quantization.", + help="Number of bits for kv cache quantization. " + "--quantized-kv-start must be provided to have an effect.", default=8, ) return parser @@ -169,10 +172,14 @@ def main(): prompt_cache, metadata = load_prompt_cache( args.prompt_cache_file, return_metadata=True, - quantized_kv=args.quantized_kv, - kv_group_size=args.kv_group_size, - kv_bits=args.kv_bits, ) + if args.quantized_kv_start and isinstance(prompt_cache[0], QuantizedKVCache): + raise ValueError( + "Specified `--quantized-kv-start` but cache from " + "`--prompt-cache-file` is already quantized." 
+ ) + + check_quantized_kv_args(args.quantized_kv_start, args.kv_group_size, args.kv_bits) # Building tokenizer_config tokenizer_config = ( @@ -248,7 +255,7 @@ def main(): top_p=args.top_p, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, - quantized_kv=args.quantized_kv, + quantized_kv_start=args.quantized_kv_start, kv_group_size=args.kv_group_size, kv_bits=args.kv_bits, ) diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index bd3d4932..0883e573 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -4,15 +4,12 @@ from typing import Any, Dict, List, Optional import mlx.core as mx import mlx.nn as nn -from mlx.utils import tree_flatten, tree_unflatten +from mlx.utils import tree_flatten, tree_map, tree_unflatten def make_prompt_cache( model: nn.Module, max_kv_size: Optional[int] = None, - quantized_kv: bool = False, - kv_group_size: int = 64, - kv_bits: int = 8, ) -> List[Any]: """ Construct the model's cache for use when generating. @@ -30,12 +27,7 @@ return model.make_cache() num_layers = len(model.layers) - if quantized_kv: - return [ - QuantizedKVCache(group_size=kv_group_size, bits=kv_bits) - for _ in range(num_layers) - ] - elif max_kv_size is not None: + if max_kv_size is not None: return [ RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers) ] @@ -62,9 +54,7 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] mx.save_safetensors(file_name, cache_data, cache_metadata) -def load_prompt_cache( - file_name, return_metadata=False, quantized_kv=False, kv_group_size=64, kv_bits=8 -): +def load_prompt_cache(file_name, return_metadata=False): """ Load a prompt cache from a file. 
@@ -85,8 +75,6 @@ def load_prompt_cache( for c, state, meta_state in zip(cache, arrays, info): c.state = state c.meta_state = meta_state - if quantized_kv: - cache = [c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in cache] if return_metadata: return cache, metadata return cache @@ -141,8 +129,44 @@ class _BaseCache: return False +def quantized_scaled_dot_product_attention( + queries: mx.array, + q_keys: tuple[mx.array, mx.array, mx.array], + q_values: tuple[mx.array, mx.array, mx.array], + scale: float, + mask: Optional[mx.array], + group_size: int = 64, + bits: int = 8, +) -> mx.array: + B, n_q_heads, L, D = queries.shape + n_kv_heads = q_keys[0].shape[-3] + n_repeats = n_q_heads // n_kv_heads + + queries *= scale + + if n_repeats > 1: + queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D)) + q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys) + q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values) + + scores = mx.quantized_matmul( + queries, *q_keys, transpose=True, group_size=group_size, bits=bits + ) + if mask is not None: + scores += mask + scores = mx.softmax(scores, axis=-1, precise=True) + out = mx.quantized_matmul( + scores, *q_values, transpose=False, group_size=group_size, bits=bits + ) + + if n_repeats > 1: + out = mx.reshape(out, (B, n_q_heads, L, D)) + + return out + + class QuantizedKVCache(_BaseCache): - def __init__(self, group_size: int = 64, bits: int = 4): + def __init__(self, group_size: int = 64, bits: int = 8): self.keys = None self.values = None self.offset = 0 @@ -154,71 +178,65 @@ class QuantizedKVCache(_BaseCache): B, n_kv_heads, num_steps, k_head_dim = keys.shape prev = self.offset - if self.keys is None or (prev + num_steps) > self.keys[0].shape[2]: + if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]: el_per_int = 8 * mx.uint32.size // self.bits - n_steps = (self.step + keys[0].shape[2] - 1) // self.step + new_steps = (self.step + num_steps - 1) // self.step * 
self.step + shape = (B, n_kv_heads, new_steps, k_head_dim // el_per_int) + group_shape = (B, n_kv_heads, new_steps, k_head_dim // self.group_size) - k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int) - scales_dim = k_head_dim // self.group_size - k_scale_shape = k_shape[:-1] + (scales_dim,) - v_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int) + def init_quant(): + return ( + mx.zeros(shape, dtype=mx.uint32), + mx.zeros(group_shape, dtype=keys.dtype), + mx.zeros(group_shape, dtype=keys.dtype), + ) - scale_bias_init = lambda: mx.zeros(k_scale_shape, keys.dtype) - new_k = (mx.zeros(k_shape, mx.uint32), scale_bias_init(), scale_bias_init()) - new_v = (mx.zeros(v_shape, mx.uint32), scale_bias_init(), scale_bias_init()) + def expand_quant(x): + new_x = mx.zeros((B, n_kv_heads, new_steps, x.shape[-1]), dtype=x.dtype) + return mx.concatenate([x, new_x], axis=-2) if self.keys is not None: if prev % self.step != 0: - self.keys = tuple(x[..., :prev, :] for x in self.keys) - self.values = tuple(x[..., :prev, :] for x in self.values) - self.keys = tuple( - mx.concatenate([self.keys[i], new_k[i]], axis=2) for i in range(3) - ) - self.values = tuple( - mx.concatenate([self.values[i], new_v[i]], axis=2) for i in range(3) + self.keys, self.values = tree_map( + lambda x: x[..., :prev, :], (self.keys, self.values) + ) + + self.keys, self.values = tree_map( + expand_quant, (self.keys, self.values) ) else: - self.keys, self.values = new_k, new_v + self.keys, self.values = init_quant(), init_quant() self.offset += num_steps - if num_steps > 1: - keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits) - values = mx.quantize(values, group_size=self.group_size, bits=self.bits) - for i in range(len(self.keys)): - self.keys[i][..., prev : self.offset, :] = keys[i] - self.values[i][..., prev : self.offset, :] = values[i] + keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits) + values = mx.quantize(values, 
group_size=self.group_size, bits=self.bits) + for i in range(len(self.keys)): + self.keys[i][..., prev : self.offset, :] = keys[i] + self.values[i][..., prev : self.offset, :] = values[i] - else: - outputs = mx.fast.quantized_kv_update( - keys, - values, - *self.keys, - *self.values, - prev, - group_size=self.group_size, - bits=self.bits - ) - self.keys = outputs[:3] - self.values = outputs[3:] - - return ( - tuple(x[..., : self.offset, :] for x in self.keys), - tuple(x[..., : self.offset, :] for x in self.values), - ) + return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values)) + @property def state(self): - return self.keys, self.values + if self.offset == self.keys[0].shape[2]: + return self.keys, self.values + else: + return tree_map( + lambda x: x[..., : self.offset, :], (self.keys, self.values) + ) - @classmethod - def from_cache( - cls, cache: _BaseCache, group_size: int = 64, bits: int = 4 - ) -> "QuantizedKVCache": - quant_cache = cls(group_size=group_size, bits=bits) - quant_cache.offset = cache.offset - quant_cache.keys = mx.quantize(cache.keys, group_size=group_size, bits=bits) - quant_cache.values = mx.quantize(cache.values, group_size=group_size, bits=bits) - return quant_cache + @state.setter + def state(self, v): + self.keys, self.values = v + + @property + def meta_state(self): + return tuple(map(str, (self.step, self.offset, self.group_size, self.bits))) + + @meta_state.setter + def meta_state(self, v): + self.step, self.offset, self.group_size, self.bits = map(int, v) class KVCache(_BaseCache): @@ -276,7 +294,11 @@ class KVCache(_BaseCache): return n def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: - return QuantizedKVCache.from_cache(self, group_size=group_size, bits=bits) + quant_cache = QuantizedKVCache(group_size=group_size, bits=bits) + quant_cache.offset = self.offset + quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits) + quant_cache.values = 
mx.quantize(self.values, group_size=group_size, bits=bits) + return quant_cache class RotatingKVCache(_BaseCache): @@ -418,6 +440,13 @@ class RotatingKVCache(_BaseCache): self._idx -= n return n + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: + quant_cache = QuantizedKVCache(group_size=group_size, bits=bits) + quant_cache.offset = self.offset + quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits) + quant_cache.values = mx.quantize(self.values, group_size=group_size, bits=bits) + return quant_cache + class MambaCache(_BaseCache): def __init__(self): diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index ffa52c8b..673d322a 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs, create_attention_mask -from .cache import QuantizedKVCache +from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention @dataclass @@ -192,10 +192,10 @@ class Attention(nn.Module): keys = self.rope(keys) if isinstance(cache, QuantizedKVCache): - output = mx.fast.quantized_scaled_dot_product_attention( + output = quantized_scaled_dot_product_attention( queries, - *keys, - *values, + keys, + values, scale=self.scale, mask=mask, group_size=cache.group_size, diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index 4e7858de..0b88e71e 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -7,6 +7,7 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs, create_attention_mask +from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention @dataclass @@ -89,9 +90,20 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask - ) + if isinstance(cache, QuantizedKVCache): + output = 
quantized_scaled_dot_product_attention( + queries, + keys, + values, + scale=self.scale, + mask=mask, + group_size=cache.group_size, + bits=cache.bits, + ) + else: + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 8e87d218..b28e90ce 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -33,6 +33,9 @@ MODEL_REMAPPING = { MAX_FILE_SIZE_GB = 5 +DEFAULT_KV_GROUP_SIZE = 64 +DEFAULT_KV_BITS = 8 + class ModelNotFoundError(Exception): def __init__(self, message): @@ -159,6 +162,27 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float) return logits +def check_quantized_kv_args(quantized_kv_start, kv_group_size, kv_bits): + if not quantized_kv_start and ( + kv_group_size != DEFAULT_KV_GROUP_SIZE or kv_bits != DEFAULT_KV_BITS + ): + raise ValueError( + "--kv-group-size and --kv-bits only apply when --quantized-kv-start is specified." 
+ ) + + +def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits): + if ( + quantized_kv_start + and prompt_cache[0].offset > quantized_kv_start + and not isinstance(prompt_cache[0], cache.QuantizedKVCache) + ): + return [ + c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in prompt_cache + ] + return prompt_cache + + def generate_step( prompt: mx.array, model: nn.Module, @@ -173,7 +197,7 @@ def generate_step( prompt_cache: Optional[Any] = None, logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, - quantized_kv: bool = False, + quantized_kv_start: Optional[int] = None, kv_group_size: int = 64, kv_bits: int = 8, ) -> Generator[Tuple[mx.array, mx.array], None, None]: @@ -261,14 +285,13 @@ def generate_step( prompt_cache = cache.make_prompt_cache( model, max_kv_size=max_kv_size, - quantized_kv=quantized_kv, - kv_group_size=kv_group_size, - kv_bits=kv_bits, ) elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") def _step(y): + + nonlocal prompt_cache logits = model(y[None], cache=prompt_cache) logits = logits[:, -1, :] @@ -279,6 +302,10 @@ def generate_step( for processor in logits_processor: logits = processor(tokens, logits) + prompt_cache = maybe_quantize_kv_cache( + prompt_cache, quantized_kv_start, kv_group_size, kv_bits + ) + y, logprobs = sample(logits) return y, logprobs.squeeze(0)