From 85ffd2c96a45a8cb900f95a2ded61d858d673399 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Thu, 31 Oct 2024 16:59:52 -0700 Subject: [PATCH] Quantized KV Cache (#1075) * add QuantizedKVCache * simplify * add tests * single sdpa function * fix sed * in place * fix tests * support different k and v head dims --- llms/mlx_lm/cache_prompt.py | 30 +++++++- llms/mlx_lm/generate.py | 38 +++++++++- llms/mlx_lm/models/base.py | 63 ++++++++++++++++ llms/mlx_lm/models/cache.py | 102 +++++++++++++++++++++++++- llms/mlx_lm/models/cohere.py | 6 +- llms/mlx_lm/models/dbrx.py | 6 +- llms/mlx_lm/models/deepseek.py | 6 +- llms/mlx_lm/models/deepseek_v2.py | 6 +- llms/mlx_lm/models/gemma.py | 6 +- llms/mlx_lm/models/gpt2.py | 6 +- llms/mlx_lm/models/gpt_bigcode.py | 6 +- llms/mlx_lm/models/gpt_neox.py | 6 +- llms/mlx_lm/models/internlm2.py | 6 +- llms/mlx_lm/models/llama.py | 9 ++- llms/mlx_lm/models/minicpm.py | 6 +- llms/mlx_lm/models/mixtral.py | 6 +- llms/mlx_lm/models/nemotron.py | 6 +- llms/mlx_lm/models/openelm.py | 6 +- llms/mlx_lm/models/phi.py | 11 ++- llms/mlx_lm/models/phi3.py | 6 +- llms/mlx_lm/models/phi3small.py | 6 +- llms/mlx_lm/models/phimoe.py | 6 +- llms/mlx_lm/models/phixtral.py | 11 ++- llms/mlx_lm/models/plamo.py | 5 +- llms/mlx_lm/models/qwen.py | 6 +- llms/mlx_lm/models/qwen2.py | 6 +- llms/mlx_lm/models/qwen2_moe.py | 6 +- llms/mlx_lm/models/recurrent_gemma.py | 6 +- llms/mlx_lm/models/stablelm.py | 6 +- llms/mlx_lm/models/starcoder2.py | 6 +- llms/mlx_lm/utils.py | 32 +++++++- llms/tests/test_prompt_cache.py | 63 ++++++++++++++++ 32 files changed, 411 insertions(+), 85 deletions(-) diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 04e75a3e..7bb06411 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -8,7 +8,9 @@ import time import mlx.core as mx from .models.cache import make_prompt_cache, save_prompt_cache -from .utils import load +from .utils import load, maybe_quantize_kv_cache + +DEFAULT_QUANTIZED_KV_START = 5000 def setup_arg_parser(): @@ -70,6 +72,26 @@ def setup_arg_parser(): required=True, help="Message to be processed by the model ('-' reads from stdin)", ) + parser.add_argument( + "--kv-bits", + type=int, + help="Number of bits for KV cache quantization. 
" + "Defaults to no quantization.", + default=None, + ) + parser.add_argument( + "--kv-group-size", + type=int, + help="Group size for KV cache quantization.", + default=64, + ) + parser.add_argument( + "--quantized-kv-start", + help="When --kv-bits is set, start quantizing the KV cache " + "from this step onwards.", + type=int, + default=DEFAULT_QUANTIZED_KV_START, + ) return parser @@ -127,6 +149,7 @@ def main(): start = time.time() max_msg_len = 0 while y.size > 0: + model(y[:step_size][None], cache=cache) mx.eval([c.state for c in cache]) processed += min(y.size, step_size) @@ -136,6 +159,11 @@ def main(): msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)" max_msg_len = max(max_msg_len, len(msg)) print(msg + " " * (max_msg_len - len(msg)), end="", flush=True) + + maybe_quantize_kv_cache( + cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits + ) + print() print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 0bf98ab2..0355ca29 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -6,7 +6,7 @@ import sys import mlx.core as mx -from .models.cache import load_prompt_cache +from .models.cache import QuantizedKVCache, load_prompt_cache from .utils import generate, load DEFAULT_PROMPT = "hello" @@ -15,6 +15,7 @@ DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 DEFAULT_SEED = 0 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" +DEFAULT_QUANTIZED_KV_START = 5000 def str2bool(string): @@ -107,6 +108,26 @@ def setup_arg_parser(): default=None, help="A file containing saved KV caches to avoid recomputing them", ) + parser.add_argument( + "--kv-bits", + type=int, + help="Number of bits for KV cache quantization. " + "Defaults to no quantization.", + default=None, + ) + parser.add_argument( + "--kv-group-size", + type=int, + help="Group size for KV cache quantization.", + default=64, + ) + parser.add_argument( + "--quantized-kv-start", + help="When --kv-bits is set, start quantizing the KV cache " + "from this step onwards.", + type=int, + default=DEFAULT_QUANTIZED_KV_START, + ) return parser @@ -150,8 +171,18 @@ def main(): using_cache = args.prompt_cache_file is not None if using_cache: prompt_cache, metadata = load_prompt_cache( - args.prompt_cache_file, return_metadata=True + args.prompt_cache_file, + return_metadata=True, ) + if isinstance(prompt_cache[0], QuantizedKVCache): + if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits: + raise ValueError( + "--kv-bits does not match the kv cache loaded from --prompt-cache-file." + ) + if args.kv_group_size != prompt_cache[0].group_size: + raise ValueError( + "--kv-group-size does not match the kv cache loaded from --prompt-cache-file." 
+ ) # Building tokenizer_config tokenizer_config = ( @@ -227,6 +258,9 @@ def main(): top_p=args.top_p, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, + kv_bits=args.kv_bits, + kv_group_size=args.kv_group_size, + quantized_kv_start=args.quantized_kv_start, ) if not args.verbose: print(response) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 3628a808..cda41c79 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -5,6 +5,9 @@ from dataclasses import dataclass from typing import Any, Optional import mlx.core as mx +from mlx.utils import tree_map + +from .cache import QuantizedKVCache @dataclass @@ -48,3 +51,63 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): else: mask = None return mask + + +def quantized_scaled_dot_product_attention( + queries: mx.array, + q_keys: tuple[mx.array, mx.array, mx.array], + q_values: tuple[mx.array, mx.array, mx.array], + scale: float, + mask: Optional[mx.array], + group_size: int = 64, + bits: int = 8, +) -> mx.array: + B, n_q_heads, L, D = queries.shape + n_kv_heads = q_keys[0].shape[-3] + n_repeats = n_q_heads // n_kv_heads + + queries *= scale + + if n_repeats > 1: + queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D)) + q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys) + q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values) + + scores = mx.quantized_matmul( + queries, *q_keys, transpose=True, group_size=group_size, bits=bits + ) + if mask is not None: + scores += mask + scores = mx.softmax(scores, axis=-1, precise=True) + out = mx.quantized_matmul( + scores, *q_values, transpose=False, group_size=group_size, bits=bits + ) + + if n_repeats > 1: + out = mx.reshape(out, (B, n_q_heads, L, D)) + + return out + + +def scaled_dot_product_attention( + queries, + keys, + values, + cache, + scale: float, + mask: Optional[mx.array], +) -> mx.array: + if isinstance(cache, QuantizedKVCache): + return quantized_scaled_dot_product_attention( + queries, + keys, + values, + scale=scale, + mask=mask, + group_size=cache.group_size, + bits=cache.bits, + ) + else: + return mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=scale, mask=mask ) diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index a6a56e0a..1cd5289d 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional import mlx.core as mx import mlx.nn as nn -from mlx.utils import tree_flatten, tree_unflatten +from mlx.utils import tree_flatten, tree_map, tree_unflatten -def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]: +def make_prompt_cache( + model: nn.Module, + max_kv_size: Optional[int] = None, +) -> List[Any]: """ Construct the model's cache for use when generating.
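A quick way to sanity-check the quantized attention path added to models/base.py above is to compare it against the regular fused kernel run on dequantized keys and values. The following is an illustrative sketch, not part of this diff; it assumes the patched package is importable as mlx_lm, that the head dimension is a multiple of the quantization group size, and that the tolerances are merely indicative.

import mlx.core as mx
from mlx_lm.models.base import quantized_scaled_dot_product_attention

B, n_q_heads, n_kv_heads, L, D = 1, 8, 2, 128, 64  # GQA shapes; D divisible by group_size
group_size, bits = 64, 8
scale = D**-0.5

queries = mx.random.normal((B, n_q_heads, 1, D)).astype(mx.float16)
keys = mx.random.normal((B, n_kv_heads, L, D)).astype(mx.float16)
values = mx.random.normal((B, n_kv_heads, L, D)).astype(mx.float16)

# Quantize K/V into the (packed weights, scales, biases) layout QuantizedKVCache stores.
q_keys = mx.quantize(keys, group_size=group_size, bits=bits)
q_values = mx.quantize(values, group_size=group_size, bits=bits)

# Reference: dequantize and use the fused kernel (it handles GQA natively).
out_ref = mx.fast.scaled_dot_product_attention(
    queries,
    mx.dequantize(*q_keys, group_size=group_size, bits=bits),
    mx.dequantize(*q_values, group_size=group_size, bits=bits),
    scale=scale,
    mask=None,
)

# Quantized path: the matmuls run directly against the packed K/V tensors.
out_q = quantized_scaled_dot_product_attention(
    queries, q_keys, q_values, scale=scale, mask=None,
    group_size=group_size, bits=bits,
)

print(mx.allclose(out_q, out_ref, rtol=1e-2, atol=1e-2))  # expected: True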
@@ -126,6 +129,88 @@ class _BaseCache: return False +class QuantizedKVCache(_BaseCache): + def __init__(self, group_size: int = 64, bits: int = 8): + self.keys = None + self.values = None + self.offset = 0 + self.step = 256 + self.group_size = group_size + self.bits = bits + + def update_and_fetch(self, keys, values): + B, n_kv_heads, num_steps, k_head_dim = keys.shape + v_head_dim = values.shape[-1] + prev = self.offset + + if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]: + el_per_int = 8 * mx.uint32.size // self.bits + new_steps = (self.step + num_steps - 1) // self.step * self.step + shape = (B, n_kv_heads, new_steps) + + def init_quant(dim): + return ( + mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32), + mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), + mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), + ) + + def expand_quant(x): + new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype) + return mx.concatenate([x, new_x], axis=-2) + + if self.keys is not None: + if prev % self.step != 0: + self.keys, self.values = tree_map( + lambda x: x[..., :prev, :], (self.keys, self.values) + ) + + self.keys, self.values = tree_map( + expand_quant, (self.keys, self.values) + ) + else: + self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim) + + self.offset += num_steps + + keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits) + values = mx.quantize(values, group_size=self.group_size, bits=self.bits) + for i in range(len(self.keys)): + self.keys[i][..., prev : self.offset, :] = keys[i] + self.values[i][..., prev : self.offset, :] = values[i] + + return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values)) + + @property + def state(self): + if self.offset == self.keys[0].shape[2]: + return self.keys, self.values + else: + return tree_map( + lambda x: x[..., : self.offset, :], (self.keys, self.values) + ) + + @state.setter + def state(self, v): + self.keys, self.values = v + + @property + def meta_state(self): + return tuple(map(str, (self.step, self.offset, self.group_size, self.bits))) + + @meta_state.setter + def meta_state(self, v): + self.step, self.offset, self.group_size, self.bits = map(int, v) + + def is_trimmable(self): + return True + + def trim(self, n): + n = min(self.offset, n) + self.offset -= n + return n + + class KVCache(_BaseCache): def __init__(self): self.keys = None @@ -180,6 +265,16 @@ class KVCache(_BaseCache): self.offset -= n return n + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: + quant_cache = QuantizedKVCache(group_size=group_size, bits=bits) + quant_cache.offset = self.offset + if self.keys is not None: + quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits) + quant_cache.values = mx.quantize( + self.values, group_size=group_size, bits=bits + ) + return quant_cache + class RotatingKVCache(_BaseCache): @@ -320,6 +415,9 @@ class RotatingKVCache(_BaseCache): self._idx -= n return n + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: + raise NotImplementedError("RotatingKVCache Quantization NYI") + class MambaCache(_BaseCache): def __init__(self): diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index 057c816d..7e002b0c 100644 --- a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base 
import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -93,8 +93,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py index 3b7e83d7..7be274cc 100644 --- a/llms/mlx_lm/models/dbrx.py +++ b/llms/mlx_lm/models/dbrx.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -74,8 +74,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.out_proj(output) diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py index 03cb3b1a..b7b24dba 100644 --- a/llms/mlx_lm/models/deepseek.py +++ b/llms/mlx_lm/models/deepseek.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -97,8 +97,8 @@ class DeepseekAttention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index bb3e5184..444813b9 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -235,8 +235,8 @@ class DeepseekV2Attention(nn.Module): queries = mx.concatenate([q_nope, q_pe], axis=-1) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py index 61de781e..3f384c3f 100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - 
queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py index 97d9a8ff..52076a34 100644 --- a/llms/mlx_lm/models/gpt2.py +++ b/llms/mlx_lm/models/gpt2.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -61,8 +61,8 @@ class Attention(nn.Module): if cache is not None: keys, values = cache.update_and_fetch(keys, values) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 068046ea..23e86e20 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -74,8 +74,8 @@ class Attention(nn.Module): if cache is not None: keys, values = cache.update_and_fetch(keys, values) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.c_proj(output) diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py index 9f662491..ccb0b28b 100644 --- a/llms/mlx_lm/models/gpt_neox.py +++ b/llms/mlx_lm/models/gpt_neox.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention # Based on the transformers implementation at: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py index 5264cb57..f5ce057e 100644 --- a/llms/mlx_lm/models/internlm2.py +++ b/llms/mlx_lm/models/internlm2.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -141,8 +141,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 
2, 1, 3).reshape(B, L, -1) return self.wo(output) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 7da6b333..438278e5 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -190,9 +190,10 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index 4ac3c3b4..907beb2a 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -105,8 +105,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - attn_output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + attn_output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index 20944fe3..dd94d1f4 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -87,8 +87,8 @@ class MixtralAttention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py index 3ea06e27..f73c0277 100644 --- a/llms/mlx_lm/models/nemotron.py +++ b/llms/mlx_lm/models/nemotron.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -113,8 +113,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/openelm.py 
b/llms/mlx_lm/models/openelm.py index 090e21c6..408802f4 100644 --- a/llms/mlx_lm/models/openelm.py +++ b/llms/mlx_lm/models/openelm.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -107,8 +107,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index 56b383b2..510025ea 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -7,7 +7,7 @@ from typing import Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -93,8 +93,13 @@ class PhiAttention(nn.Module): keys = self.rope(keys) scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries.astype(mx.float32), keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries.astype(mx.float32), + keys, + values, + cache=cache, + scale=scale, + mask=mask, ).astype(values.dtype) output = output.moveaxis(2, 1).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index 9ef76f04..ee6efc49 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .su_rope import SuScaledRotaryEmbedding @@ -107,8 +107,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index 6b0759b4..53e1a638 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -8,7 +8,7 @@ from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -188,8 +188,8 @@ class Attention(nn.Module): queries, keys, values, scale=self.scale, mask=mask ) else: - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.dense(output) diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py index ca20a388..f42a6dd0 100644 --- a/llms/mlx_lm/models/phimoe.py +++ b/llms/mlx_lm/models/phimoe.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Union import 
mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .su_rope import SuScaledRotaryEmbedding from .switch_layers import SwitchGLU @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index 865d0d8e..42d647b0 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -8,7 +8,7 @@ from typing import Tuple import mlx.core as mx import mlx.nn as nn -from .base import create_attention_mask +from .base import create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchMLP @@ -71,8 +71,13 @@ class RoPEAttention(nn.Module): # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries.astype(mx.float32), keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries.astype(mx.float32), + keys, + values, + cache=cache, + scale=scale, + mask=mask, ).astype(values.dtype) output = output.moveaxis(2, 1).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index b0fd1a6c..c8e5bf50 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -92,10 +92,11 @@ class Attention(nn.Module): keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1]) values = mx.tile(values, [1, self.config.n_shared_head, 1, 1]) - output = mx.fast.scaled_dot_product_attention( + output = scaled_dot_product_attention( queries, keys, values, + cache=cache, scale=self.scale, mask=attention_mask, ) diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index 2b69d5ec..8145a890 100644 --- a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -5,7 +5,7 @@ from dataclasses import dataclass import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -64,8 +64,8 @@ class Attention(nn.Module): queries = self.rotary_emb(queries) keys = self.rotary_emb(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index 4e7858de..fac59d78 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -89,8 +89,8 @@ class Attention(nn.Module): queries = self.rope(queries) 
keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index d199116f..167fc5dd 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -89,8 +89,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 5595d311..49e4bb8f 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -7,7 +7,7 @@ from typing import List, Literal, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .cache import MambaCache, RotatingKVCache @@ -263,8 +263,8 @@ class LocalAttentionBlock(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index 11202b02..482bb324 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -6,7 +6,7 @@ from dataclasses import dataclass import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -120,8 +120,8 @@ class Attention(nn.Module): # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=scale, mask=mask ).astype(values.dtype) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index ce0a2ec5..d7e626f2 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -6,7 +6,7 @@ from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -64,8 +64,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, 
values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 5b437c98..06784f10 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -19,7 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce from transformers import PreTrainedTokenizer # Local imports -from .models import base, cache +from .models import cache from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import dequantize as dequantize_model @@ -159,6 +159,18 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float) return logits +def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits): + if ( + kv_bits is not None + and not isinstance(prompt_cache[0], cache.QuantizedKVCache) + and prompt_cache[0].offset > quantized_kv_start + ): + for i in range(len(prompt_cache)): + prompt_cache[i] = prompt_cache[i].to_quantized( + group_size=kv_group_size, bits=kv_bits + ) + + def generate_step( prompt: mx.array, model: nn.Module, @@ -173,6 +185,9 @@ def generate_step( prompt_cache: Optional[Any] = None, logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + kv_bits: Optional[int] = None, + kv_group_size: int = 64, + quantized_kv_start: int = 0, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -201,6 +216,11 @@ def generate_step( logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): A list of functions that take tokens and logits and return the processed logits. Default: ``None``. + kv_bits (int, optional): Number of bits to use for KV cache quantization. + None implies no cache quantization. Default: ``None``. + kv_group_size (int): Group size for KV cache quantization. Default: ``64``. + quantized_kv_start (int): Step to begin using a quantized KV cache. + when ``kv_bits`` is non-None. Default: ``0``. 
Yields: Generator[Tuple[mx.array, mx.array], None, None]: A generator producing @@ -255,11 +275,15 @@ def generate_step( # Create the KV cache for generation if prompt_cache is None: - prompt_cache = cache.make_prompt_cache(model, max_kv_size) + prompt_cache = cache.make_prompt_cache( + model, + max_kv_size=max_kv_size, + ) elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") def _step(y): + logits = model(y[None], cache=prompt_cache) logits = logits[:, -1, :] @@ -270,6 +294,10 @@ def generate_step( for processor in logits_processor: logits = processor(tokens, logits) + maybe_quantize_kv_cache( + prompt_cache, quantized_kv_start, kv_group_size, kv_bits + ) + y, logprobs = sample(logits) return y, logprobs.squeeze(0) diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index 64cd9486..1e57bd86 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -9,6 +9,7 @@ import mlx.core as mx from mlx_lm.models.cache import ( KVCache, MambaCache, + QuantizedKVCache, RotatingKVCache, load_prompt_cache, make_prompt_cache, @@ -186,6 +187,18 @@ class TestPromptCache(unittest.TestCase): num_trimmed = trim_prompt_cache(cache, 4) self.assertEqual(num_trimmed, 0) + cache = [QuantizedKVCache() for _ in range(2)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 64)) + c.update_and_fetch(x, x) + + num_trimmed = trim_prompt_cache(cache, 7) + self.assertEqual(num_trimmed, 7) + + # Trim more tokens than remain + num_trimmed = trim_prompt_cache(cache, 4) + self.assertEqual(num_trimmed, 3) + def test_trim_cache_with_generate(self): model, tokenizer = load(HF_MODEL_PATH) prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] @@ -238,6 +251,56 @@ class TestPromptCache(unittest.TestCase): self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y)) self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z)) + def test_save_load_quantized_cache(self): + cache = [QuantizedKVCache(bits=4, group_size=32) for _ in range(4)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 32)) + c.update_and_fetch(x, x) + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + self.assertTrue(loaded_cache[0].bits == cache[0].bits) + self.assertTrue(loaded_cache[0].group_size == cache[0].group_size) + self.assertTrue(len(cache), len(loaded_cache)) + for c, lc in zip(cache, loaded_cache): + self.assertEqual(c.offset, lc.offset) + # Loop over quantized tuple + for i in range(3): + self.assertTrue(mx.array_equal(c.state[0][i], lc.state[0][i])) + self.assertTrue(mx.array_equal(c.state[1][i], lc.state[1][i])) + + # Test with metadata + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + metadata = {"a": "b", "c": "d"} + save_prompt_cache(cache_file, cache, metadata) + _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True) + self.assertEqual(metadata, loaded_metadata) + + def test_cache_to_quantized(self): + model, tokenizer = load(HF_MODEL_PATH) + prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] + results = zip(range(4), generate_step(prompt, model)) + toks, all_logits = zip(*(r[1] for r in results)) + + prompt_cache = make_prompt_cache(model) + i = 0 + for _, (tok, logits) in zip( + range(2), generate_step(prompt, model, prompt_cache=prompt_cache) + ): + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, 
all_logits[i])) + i += 1 + + prompt_cache = [c.to_quantized(bits=8, group_size=32) for c in prompt_cache] + + for _, (tok, logits) in zip( + range(1), + generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache), + ): + i += 1 + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2)) + if __name__ == "__main__": unittest.main()
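Usage note (illustrative, not part of this diff): with this change, KV cache quantization is driven by the new --kv-bits, --kv-group-size, and --quantized-kv-start flags on the CLI, or by the matching keyword arguments on generate_step()/generate(). A minimal sketch, assuming the patched package is installed as mlx_lm; the model name and prompt are placeholders.

# CLI equivalent (flags added by this patch):
#   mlx_lm.generate --model mlx-community/Llama-3.2-3B-Instruct-4bit \
#       --prompt "hello" --max-tokens 256 \
#       --kv-bits 4 --kv-group-size 64 --quantized-kv-start 1000

from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# The keyword arguments are forwarded to generate_step(), which calls
# maybe_quantize_kv_cache() at each step: once the cache offset exceeds
# quantized_kv_start, each layer's KVCache is swapped for a QuantizedKVCache
# via to_quantized(group_size=kv_group_size, bits=kv_bits).
text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about the ocean.",
    max_tokens=256,
    kv_bits=4,
    kv_group_size=64,
    quantized_kv_start=1000,
)
print(text)

Note that --kv-bits defaults to None (no quantization), and when a quantized cache is loaded via --prompt-cache-file, the checks added to generate.py require --kv-bits (if set) and --kv-group-size to match the loaded cache.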