Quantized KV Cache (#1075)

* add QuantizedKVCache

* simplify

* add tests

* single sdpa function

* fix sed

* in place

* fix tests

* support different k and v head dims
Alex Barron, 2024-10-31 16:59:52 -07:00, committed by GitHub
parent 9f34fdbda4
commit 85ffd2c96a
32 changed files with 411 additions and 85 deletions

@@ -8,7 +8,9 @@ import time
 import mlx.core as mx
 from .models.cache import make_prompt_cache, save_prompt_cache
-from .utils import load
+from .utils import load, maybe_quantize_kv_cache
+
+DEFAULT_QUANTIZED_KV_START = 5000
 def setup_arg_parser():
@@ -70,6 +72,26 @@ def setup_arg_parser():
         required=True,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
+    parser.add_argument(
+        "--kv-bits",
+        type=int,
+        help="Number of bits for KV cache quantization. "
+        "Defaults to no quantization.",
+        default=None,
+    )
+    parser.add_argument(
+        "--kv-group-size",
+        type=int,
+        help="Group size for KV cache quantization.",
+        default=64,
+    )
+    parser.add_argument(
+        "--quantized-kv-start",
+        help="When --kv-bits is set, start quantizing the KV cache "
+        "from this step onwards.",
+        type=int,
+        default=DEFAULT_QUANTIZED_KV_START,
+    )
     return parser
@@ -127,6 +149,7 @@ def main():
     start = time.time()
     max_msg_len = 0
     while y.size > 0:
         model(y[:step_size][None], cache=cache)
         mx.eval([c.state for c in cache])
         processed += min(y.size, step_size)
@@ -136,6 +159,11 @@ def main():
         msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
         max_msg_len = max(max_msg_len, len(msg))
         print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
+        maybe_quantize_kv_cache(
+            cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
+        )
     print()
     print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")

@@ -6,7 +6,7 @@ import sys
 import mlx.core as mx
-from .models.cache import load_prompt_cache
+from .models.cache import QuantizedKVCache, load_prompt_cache
 from .utils import generate, load
 DEFAULT_PROMPT = "hello"
@@ -15,6 +15,7 @@ DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
+DEFAULT_QUANTIZED_KV_START = 5000
 def str2bool(string):
@@ -107,6 +108,26 @@ def setup_arg_parser():
         default=None,
         help="A file containing saved KV caches to avoid recomputing them",
     )
+    parser.add_argument(
+        "--kv-bits",
+        type=int,
+        help="Number of bits for KV cache quantization. "
+        "Defaults to no quantization.",
+        default=None,
+    )
+    parser.add_argument(
+        "--kv-group-size",
+        type=int,
+        help="Group size for KV cache quantization.",
+        default=64,
+    )
+    parser.add_argument(
+        "--quantized-kv-start",
+        help="When --kv-bits is set, start quantizing the KV cache "
+        "from this step onwards.",
+        type=int,
+        default=DEFAULT_QUANTIZED_KV_START,
+    )
     return parser
@@ -150,8 +171,18 @@ def main():
     using_cache = args.prompt_cache_file is not None
     if using_cache:
         prompt_cache, metadata = load_prompt_cache(
-            args.prompt_cache_file, return_metadata=True
+            args.prompt_cache_file,
+            return_metadata=True,
         )
+        if isinstance(prompt_cache[0], QuantizedKVCache):
+            if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
+                raise ValueError(
+                    "--kv-bits does not match the kv cache loaded from --prompt-cache-file."
+                )
+            if args.kv_group_size != prompt_cache[0].group_size:
+                raise ValueError(
+                    "--kv-group-size does not match the kv cache loaded from --prompt-cache-file."
+                )
     # Building tokenizer_config
     tokenizer_config = (
@@ -227,6 +258,9 @@ def main():
         top_p=args.top_p,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
+        kv_bits=args.kv_bits,
+        kv_group_size=args.kv_group_size,
+        quantized_kv_start=args.quantized_kv_start,
     )
     if not args.verbose:
         print(response)
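The same flags map onto keyword arguments of the Python API. A minimal sketch, assuming `generate` forwards these keyword arguments to `generate_step` the way the `main()` change above does; the prompt and the quantization settings are illustrative:

from mlx_lm.utils import generate, load

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# Quantize the KV cache to 8 bits (group size 64) once more than
# 1000 tokens have been cached; kv_bits=None keeps the cache unquantized.
text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about caches.",
    max_tokens=256,
    kv_bits=8,
    kv_group_size=64,
    quantized_kv_start=1000,
)
print(text)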

@@ -5,6 +5,9 @@ from dataclasses import dataclass
 from typing import Any, Optional
 import mlx.core as mx
+from mlx.utils import tree_map
+
+from .cache import QuantizedKVCache
 @dataclass
@@ -48,3 +51,63 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
     else:
         mask = None
     return mask
+
+
+def quantized_scaled_dot_product_attention(
+    queries: mx.array,
+    q_keys: tuple[mx.array, mx.array, mx.array],
+    q_values: tuple[mx.array, mx.array, mx.array],
+    scale: float,
+    mask: Optional[mx.array],
+    group_size: int = 64,
+    bits: int = 8,
+) -> mx.array:
+    B, n_q_heads, L, D = queries.shape
+    n_kv_heads = q_keys[0].shape[-3]
+    n_repeats = n_q_heads // n_kv_heads
+
+    queries *= scale
+
+    if n_repeats > 1:
+        queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
+        q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
+        q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
+
+    scores = mx.quantized_matmul(
+        queries, *q_keys, transpose=True, group_size=group_size, bits=bits
+    )
+    if mask is not None:
+        scores += mask
+    scores = mx.softmax(scores, axis=-1, precise=True)
+    out = mx.quantized_matmul(
+        scores, *q_values, transpose=False, group_size=group_size, bits=bits
+    )
+
+    if n_repeats > 1:
+        out = mx.reshape(out, (B, n_q_heads, L, D))
+
+    return out
+
+
+def scaled_dot_product_attention(
+    queries,
+    keys,
+    values,
+    cache,
+    scale: float,
+    mask: Optional[mx.array],
+) -> mx.array:
+    if isinstance(cache, QuantizedKVCache):
+        return quantized_scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            scale=scale,
+            mask=mask,
+            group_size=cache.group_size,
+            bits=cache.bits,
+        )
+    else:
+        return mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=scale, mask=mask
+        )

@@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional
 import mlx.core as mx
 import mlx.nn as nn
-from mlx.utils import tree_flatten, tree_unflatten
+from mlx.utils import tree_flatten, tree_map, tree_unflatten
-def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
+def make_prompt_cache(
+    model: nn.Module,
+    max_kv_size: Optional[int] = None,
+) -> List[Any]:
     """
     Construct the model's cache for use when generating.
@@ -126,6 +129,88 @@ class _BaseCache:
         return False
+
+
+class QuantizedKVCache(_BaseCache):
+    def __init__(self, group_size: int = 64, bits: int = 8):
+        self.keys = None
+        self.values = None
+        self.offset = 0
+        self.step = 256
+        self.group_size = group_size
+        self.bits = bits
+
+    def update_and_fetch(self, keys, values):
+        B, n_kv_heads, num_steps, k_head_dim = keys.shape
+        v_head_dim = values.shape[-1]
+        prev = self.offset
+
+        if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
+            el_per_int = 8 * mx.uint32.size // self.bits
+            new_steps = (self.step + num_steps - 1) // self.step * self.step
+            shape = (B, n_kv_heads, new_steps)
+
+            def init_quant(dim):
+                return (
+                    mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
+                    mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
+                    mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
+                )
+
+            def expand_quant(x):
+                new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
+                return mx.concatenate([x, new_x], axis=-2)
+
+            if self.keys is not None:
+                if prev % self.step != 0:
+                    self.keys, self.values = tree_map(
+                        lambda x: x[..., :prev, :], (self.keys, self.values)
+                    )
+
+                self.keys, self.values = tree_map(
+                    expand_quant, (self.keys, self.values)
+                )
+            else:
+                self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
+
+        self.offset += num_steps
+
+        keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
+        values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
+        for i in range(len(self.keys)):
+            self.keys[i][..., prev : self.offset, :] = keys[i]
+            self.values[i][..., prev : self.offset, :] = values[i]
+
+        return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))
+
+    @property
+    def state(self):
+        if self.offset == self.keys[0].shape[2]:
+            return self.keys, self.values
+        else:
+            return tree_map(
+                lambda x: x[..., : self.offset, :], (self.keys, self.values)
+            )
+
+    @state.setter
+    def state(self, v):
+        self.keys, self.values = v
+
+    @property
+    def meta_state(self):
+        return tuple(map(str, (self.step, self.offset, self.group_size, self.bits)))
+
+    @meta_state.setter
+    def meta_state(self, v):
+        self.step, self.offset, self.group_size, self.bits = map(int, v)
+
+    def is_trimmable(self):
+        return True
+
+    def trim(self, n):
+        n = min(self.offset, n)
+        self.offset -= n
+        return n
+
+
 class KVCache(_BaseCache):
     def __init__(self):
         self.keys = None
@@ -180,6 +265,16 @@ class KVCache(_BaseCache):
         self.offset -= n
         return n
+
+    def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
+        quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
+        quant_cache.offset = self.offset
+        if self.keys is not None:
+            quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
+            quant_cache.values = mx.quantize(
+                self.values, group_size=group_size, bits=bits
+            )
+        return quant_cache
 class RotatingKVCache(_BaseCache):
@@ -320,6 +415,9 @@ class RotatingKVCache(_BaseCache):
         self._idx -= n
         return n
+
+    def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
+        raise NotImplementedError("RotatingKVCache Quantization NYI")
 class MambaCache(_BaseCache):
     def __init__(self):
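The new cache classes can be exercised on their own. A minimal sketch of converting a populated `KVCache` into a `QuantizedKVCache` with `to_quantized`, assuming the classes are importable from `mlx_lm.models.cache` as in the tests further below; the shapes are illustrative:

import mlx.core as mx
from mlx_lm.models.cache import KVCache

B, n_kv_heads, head_dim = 1, 8, 64

# Fill an ordinary cache with ten steps of keys/values.
kv = KVCache()
k = mx.random.normal((B, n_kv_heads, 10, head_dim))
v = mx.random.normal((B, n_kv_heads, 10, head_dim))
kv.update_and_fetch(k, v)

# Convert it: keys/values become (packed data, scales, biases) tuples
# and the offset carries over.
q_kv = kv.to_quantized(group_size=64, bits=4)
print(q_kv.offset)      # 10
print(len(q_kv.keys))   # 3 arrays: packed data, scales, biases

# Later updates are quantized on the fly.
q_kv.update_and_fetch(k, v)
print(q_kv.offset)      # 20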

@@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -93,8 +93,8 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

@@ -7,7 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -74,8 +74,8 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.out_proj(output)

@@ -4,7 +4,7 @@ from typing import Any, Dict, Optional
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .switch_layers import SwitchGLU
@@ -97,8 +97,8 @@ class DeepseekAttention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

@@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .switch_layers import SwitchGLU
@@ -235,8 +235,8 @@ class DeepseekV2Attention(nn.Module):
         queries = mx.concatenate([q_nope, q_pe], axis=-1)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

@@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -79,8 +79,8 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

@@ -7,7 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -61,8 +61,8 @@ class Attention(nn.Module):
         if cache is not None:
             keys, values = cache.update_and_fetch(keys, values)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

@@ -7,7 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -74,8 +74,8 @@ class Attention(nn.Module):
         if cache is not None:
             keys, values = cache.update_and_fetch(keys, values)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.c_proj(output)

@@ -7,7 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 # Based on the transformers implementation at:
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -79,8 +79,8 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

@@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -141,8 +141,8 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.wo(output)

@@ -1,12 +1,12 @@
 # Copyright © 2023-2024 Apple Inc.
 from dataclasses import dataclass
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -190,9 +190,10 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

@@ -7,7 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -105,8 +105,8 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        attn_output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        attn_output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)

@@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .switch_layers import SwitchGLU
@@ -87,8 +87,8 @@ class MixtralAttention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

@@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -113,8 +113,8 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -107,8 +107,8 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

@@ -7,7 +7,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -93,8 +93,13 @@ class PhiAttention(nn.Module):
         keys = self.rope(keys)
         scale = math.sqrt(1 / queries.shape[-1])
-        output = mx.fast.scaled_dot_product_attention(
-            queries.astype(mx.float32), keys, values, scale=scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries.astype(mx.float32),
+            keys,
+            values,
+            cache=cache,
+            scale=scale,
+            mask=mask,
         ).astype(values.dtype)
         output = output.moveaxis(2, 1).reshape(B, L, -1)

@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .su_rope import SuScaledRotaryEmbedding
@@ -107,8 +107,8 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

@@ -8,7 +8,7 @@ from typing import Any, Optional
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -188,8 +188,8 @@ class Attention(nn.Module):
                 queries, keys, values, scale=self.scale, mask=mask
             )
         else:
-            output = mx.fast.scaled_dot_product_attention(
-                queries, keys, values, scale=self.scale, mask=mask
+            output = scaled_dot_product_attention(
+                queries, keys, values, cache=cache, scale=self.scale, mask=mask
             )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.dense(output)

@@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .su_rope import SuScaledRotaryEmbedding
 from .switch_layers import SwitchGLU
@@ -79,8 +79,8 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

@@ -8,7 +8,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
-from .base import create_attention_mask
+from .base import create_attention_mask, scaled_dot_product_attention
 from .switch_layers import SwitchMLP
@@ -71,8 +71,13 @@ class RoPEAttention(nn.Module):
         # Finally perform the attention computation
         scale = math.sqrt(1 / queries.shape[-1])
-        output = mx.fast.scaled_dot_product_attention(
-            queries.astype(mx.float32), keys, values, scale=scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries.astype(mx.float32),
+            keys,
+            values,
+            cache=cache,
+            scale=scale,
+            mask=mask,
         ).astype(values.dtype)
         output = output.moveaxis(2, 1).reshape(B, L, -1)

@@ -7,7 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -92,10 +92,11 @@ class Attention(nn.Module):
             keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
             values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
-        output = mx.fast.scaled_dot_product_attention(
+        output = scaled_dot_product_attention(
             queries,
             keys,
             values,
+            cache=cache,
             scale=self.scale,
             mask=attention_mask,
         )

@@ -5,7 +5,7 @@ from dataclasses import dataclass
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -64,8 +64,8 @@ class Attention(nn.Module):
         queries = self.rotary_emb(queries)
         keys = self.rotary_emb(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

@@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -89,8 +89,8 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

@@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .switch_layers import SwitchGLU
@@ -89,8 +89,8 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

@@ -7,7 +7,7 @@ from typing import List, Literal, Optional
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .cache import MambaCache, RotatingKVCache
@@ -263,8 +263,8 @@ class LocalAttentionBlock(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

@@ -6,7 +6,7 @@ from dataclasses import dataclass
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -120,8 +120,8 @@ class Attention(nn.Module):
         # Finally perform the attention computation
         scale = math.sqrt(1 / queries.shape[-1])
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=scale, mask=mask
         ).astype(values.dtype)
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

@@ -6,7 +6,7 @@ from typing import Any, Optional
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 @dataclass
@@ -64,8 +64,8 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

@@ -19,7 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 # Local imports
-from .models import base, cache
+from .models import cache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
@@ -159,6 +159,18 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
     return logits
+
+
+def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
+    if (
+        kv_bits is not None
+        and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
+        and prompt_cache[0].offset > quantized_kv_start
+    ):
+        for i in range(len(prompt_cache)):
+            prompt_cache[i] = prompt_cache[i].to_quantized(
+                group_size=kv_group_size, bits=kv_bits
+            )
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
@@ -173,6 +185,9 @@ def generate_step(
     prompt_cache: Optional[Any] = None,
     logit_bias: Optional[Dict[int, float]] = None,
     logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    kv_bits: Optional[int] = None,
+    kv_group_size: int = 64,
+    quantized_kv_start: int = 0,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -201,6 +216,11 @@ def generate_step(
         logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
             A list of functions that take tokens and logits and return the processed
             logits. Default: ``None``.
+        kv_bits (int, optional): Number of bits to use for KV cache quantization.
+            None implies no cache quantization. Default: ``None``.
+        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
+        quantized_kv_start (int): Step to begin using a quantized KV cache
+            when ``kv_bits`` is non-None. Default: ``0``.
     Yields:
         Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@@ -255,11 +275,15 @@ def generate_step(
     # Create the KV cache for generation
     if prompt_cache is None:
-        prompt_cache = cache.make_prompt_cache(model, max_kv_size)
+        prompt_cache = cache.make_prompt_cache(
+            model,
+            max_kv_size=max_kv_size,
+        )
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")
     def _step(y):
         logits = model(y[None], cache=prompt_cache)
         logits = logits[:, -1, :]
@@ -270,6 +294,10 @@ def generate_step(
             for processor in logits_processor:
                 logits = processor(tokens, logits)
+        maybe_quantize_kv_cache(
+            prompt_cache, quantized_kv_start, kv_group_size, kv_bits
+        )
+
        y, logprobs = sample(logits)
        return y, logprobs.squeeze(0)
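At the lowest level the new keyword arguments go straight to `generate_step`, which calls `maybe_quantize_kv_cache` after each step. A minimal sketch, assuming a model and tokenizer loaded via `mlx_lm.utils.load`; the prompt, token budget, and quantization settings are illustrative:

import mlx.core as mx
from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
prompt = mx.array(tokenizer.encode("Explain KV cache quantization."))

tokens = []
for (token, _logprobs), _ in zip(
    generate_step(
        prompt,
        model,
        kv_bits=4,             # quantize the KV cache to 4 bits
        kv_group_size=64,      # quantization group size
        quantized_kv_start=0,  # quantize from the first generation step
    ),
    range(64),                 # token budget for the sketch
):
    if token == tokenizer.eos_token_id:
        break
    tokens.append(token.item())

print(tokenizer.decode(tokens))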

@@ -9,6 +9,7 @@ import mlx.core as mx
 from mlx_lm.models.cache import (
     KVCache,
     MambaCache,
+    QuantizedKVCache,
     RotatingKVCache,
     load_prompt_cache,
     make_prompt_cache,
@@ -186,6 +187,18 @@ class TestPromptCache(unittest.TestCase):
         num_trimmed = trim_prompt_cache(cache, 4)
         self.assertEqual(num_trimmed, 0)
+        cache = [QuantizedKVCache() for _ in range(2)]
+        for c in cache:
+            x = mx.random.uniform(shape=(1, 8, 10, 64))
+            c.update_and_fetch(x, x)
+
+        num_trimmed = trim_prompt_cache(cache, 7)
+        self.assertEqual(num_trimmed, 7)
+
+        # Trim more tokens than remain
+        num_trimmed = trim_prompt_cache(cache, 4)
+        self.assertEqual(num_trimmed, 3)
+
     def test_trim_cache_with_generate(self):
         model, tokenizer = load(HF_MODEL_PATH)
         prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
@@ -238,6 +251,56 @@ class TestPromptCache(unittest.TestCase):
         self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
         self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
+    def test_save_load_quantized_cache(self):
+        cache = [QuantizedKVCache(bits=4, group_size=32) for _ in range(4)]
+        for c in cache:
+            x = mx.random.uniform(shape=(1, 8, 10, 32))
+            c.update_and_fetch(x, x)
+        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
+        save_prompt_cache(cache_file, cache)
+        loaded_cache = load_prompt_cache(cache_file)
+        self.assertTrue(loaded_cache[0].bits == cache[0].bits)
+        self.assertTrue(loaded_cache[0].group_size == cache[0].group_size)
+        self.assertTrue(len(cache), len(loaded_cache))
+        for c, lc in zip(cache, loaded_cache):
+            self.assertEqual(c.offset, lc.offset)
+            # Loop over quantized tuple
+            for i in range(3):
+                self.assertTrue(mx.array_equal(c.state[0][i], lc.state[0][i]))
+                self.assertTrue(mx.array_equal(c.state[1][i], lc.state[1][i]))
+
+        # Test with metadata
+        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
+        metadata = {"a": "b", "c": "d"}
+        save_prompt_cache(cache_file, cache, metadata)
+        _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
+        self.assertEqual(metadata, loaded_metadata)
+
+    def test_cache_to_quantized(self):
+        model, tokenizer = load(HF_MODEL_PATH)
+        prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
+        results = zip(range(4), generate_step(prompt, model))
+        toks, all_logits = zip(*(r[1] for r in results))
+
+        prompt_cache = make_prompt_cache(model)
+        i = 0
+        for _, (tok, logits) in zip(
+            range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
+        ):
+            self.assertEqual(tok, toks[i])
+            self.assertTrue(mx.allclose(logits, all_logits[i]))
+            i += 1
+
+        prompt_cache = [c.to_quantized(bits=8, group_size=32) for c in prompt_cache]
+
+        for _, (tok, logits) in zip(
+            range(1),
+            generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
+        ):
+            i += 1
+            self.assertEqual(tok, toks[i])
+            self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2))
+
 if __name__ == "__main__":
     unittest.main()