Alex Barron 2024-10-28 16:03:43 -07:00
parent 48655a7f83
commit 37a3723823
6 changed files with 197 additions and 90 deletions

View File

@@ -8,7 +8,13 @@ import time
 import mlx.core as mx

 from .models.cache import make_prompt_cache, save_prompt_cache
-from .utils import load
+from .utils import (
+    DEFAULT_KV_BITS,
+    DEFAULT_KV_GROUP_SIZE,
+    check_quantized_kv_args,
+    load,
+    maybe_quantize_kv_cache,
+)


 def setup_arg_parser():
@@ -70,6 +76,24 @@ def setup_arg_parser():
         required=True,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
+    parser.add_argument(
+        "--quantized-kv-start",
+        help="Use a quantized KV cache from this step onwards.",
+        type=int,
+        default=None,
+    )
+    parser.add_argument(
+        "--kv-group-size",
+        type=int,
+        help="Group size for kv cache quantization.",
+        default=DEFAULT_KV_GROUP_SIZE,
+    )
+    parser.add_argument(
+        "--kv-bits",
+        type=int,
+        help="Number of bits for kv cache quantization.",
+        default=DEFAULT_KV_BITS,
+    )
     return parser
@@ -93,6 +117,8 @@ def main():
     args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt

+    check_quantized_kv_args(args.quantized_kv_start, args.kv_group_size, args.kv_bits)
+
     if args.use_default_chat_template:
         if tokenizer.chat_template is None:
             tokenizer.chat_template = tokenizer.default_chat_template
@@ -127,6 +153,7 @@ def main():
     start = time.time()
     max_msg_len = 0
     while y.size > 0:
         model(y[:step_size][None], cache=cache)
         mx.eval([c.state for c in cache])
         processed += min(y.size, step_size)
@@ -136,6 +163,11 @@ def main():
         msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
         max_msg_len = max(max_msg_len, len(msg))
         print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
+
+        cache = maybe_quantize_kv_cache(
+            cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
+        )
+
     print()
     print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")

View File

@@ -6,8 +6,8 @@ import sys
 import mlx.core as mx

-from .models.cache import load_prompt_cache
-from .utils import generate, load
+from .models.cache import QuantizedKVCache, load_prompt_cache
+from .utils import check_quantized_kv_args, generate, load

 DEFAULT_PROMPT = "hello"
 DEFAULT_MAX_TOKENS = 100
@@ -108,20 +108,23 @@ def setup_arg_parser():
         help="A file containing saved KV caches to avoid recomputing them",
     )
     parser.add_argument(
-        "--quantized-kv",
-        help="Whether to quantize the KV cache.",
-        action="store_true",
+        "--quantized-kv-start",
+        help="Use a quantized KV cache from this step onwards.",
+        type=int,
+        default=None,
     )
     parser.add_argument(
         "--kv-group-size",
         type=int,
-        help="Group size for kv cache quantization.",
+        help="Group size for kv cache quantization. "
+        "--quantized-kv-start must be provided to have an effect.",
         default=64,
     )
     parser.add_argument(
         "--kv-bits",
         type=int,
-        help="Number of bits for kv cache quantization.",
+        help="Number of bits for kv cache quantization. "
+        "--quantized-kv-start must be provided to have an effect.",
         default=8,
     )
     return parser
@@ -169,10 +172,14 @@ def main():
         prompt_cache, metadata = load_prompt_cache(
             args.prompt_cache_file,
             return_metadata=True,
-            quantized_kv=args.quantized_kv,
-            kv_group_size=args.kv_group_size,
-            kv_bits=args.kv_bits,
         )
+        if args.quantized_kv_start and isinstance(prompt_cache[0], QuantizedKVCache):
+            raise ValueError(
+                "Specified `--quantized-kv-start` but cache from "
+                "`--prompt-cache-file` is already quantized."
+            )
+
+    check_quantized_kv_args(args.quantized_kv_start, args.kv_group_size, args.kv_bits)

     # Building tokenizer_config
     tokenizer_config = (
@@ -248,7 +255,7 @@ def main():
         top_p=args.top_p,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
-        quantized_kv=args.quantized_kv,
+        quantized_kv_start=args.quantized_kv_start,
         kv_group_size=args.kv_group_size,
         kv_bits=args.kv_bits,
     )
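On the Python side the renamed option threads straight through generate() into generate_step() (see the utils.py hunks later in this diff, which forward extra keyword arguments). A hedged usage sketch, assuming the mlx_lm package built from this commit; the model path and values are placeholders:

from mlx_lm import generate, load

model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")  # placeholder
text = generate(
    model,
    tokenizer,
    prompt="Explain KV cache quantization in one sentence.",
    max_tokens=64,
    quantized_kv_start=256,  # keep the first 256 positions in fp16, quantize afterwards
    kv_group_size=64,
    kv_bits=8,
)
print(text)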

View File

@@ -4,15 +4,12 @@ from typing import Any, Dict, List, Optional
 import mlx.core as mx
 import mlx.nn as nn
-from mlx.utils import tree_flatten, tree_unflatten
+from mlx.utils import tree_flatten, tree_map, tree_unflatten


 def make_prompt_cache(
     model: nn.Module,
     max_kv_size: Optional[int] = None,
-    quantized_kv: bool = False,
-    kv_group_size: int = 64,
-    kv_bits: int = 8,
 ) -> List[Any]:
     """
     Construct the model's cache for use when generating.
@@ -30,12 +27,7 @@ def make_prompt_cache(
         return model.make_cache()

     num_layers = len(model.layers)
-    if quantized_kv:
-        return [
-            QuantizedKVCache(group_size=kv_group_size, bits=kv_bits)
-            for _ in range(num_layers)
-        ]
-    elif max_kv_size is not None:
+    if max_kv_size is not None:
         return [
             RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
         ]
@@ -62,9 +54,7 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str]
     mx.save_safetensors(file_name, cache_data, cache_metadata)


-def load_prompt_cache(
-    file_name, return_metadata=False, quantized_kv=False, kv_group_size=64, kv_bits=8
-):
+def load_prompt_cache(file_name, return_metadata=False):
     """
     Load a prompt cache from a file.
@@ -85,8 +75,6 @@ def load_prompt_cache(
     for c, state, meta_state in zip(cache, arrays, info):
         c.state = state
         c.meta_state = meta_state
-    if quantized_kv:
-        cache = [c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in cache]
     if return_metadata:
         return cache, metadata
     return cache
@@ -141,8 +129,44 @@ class _BaseCache:
         return False


+def quantized_scaled_dot_product_attention(
+    queries: mx.array,
+    q_keys: tuple[mx.array, mx.array, mx.array],
+    q_values: tuple[mx.array, mx.array, mx.array],
+    scale: float,
+    mask: Optional[mx.array],
+    group_size: int = 64,
+    bits: int = 8,
+) -> mx.array:
+    B, n_q_heads, L, D = queries.shape
+    n_kv_heads = q_keys[0].shape[-3]
+    n_repeats = n_q_heads // n_kv_heads
+
+    queries *= scale
+
+    if n_repeats > 1:
+        queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
+        q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
+        q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
+
+    scores = mx.quantized_matmul(
+        queries, *q_keys, transpose=True, group_size=group_size, bits=bits
+    )
+    if mask is not None:
+        scores += mask
+    scores = mx.softmax(scores, axis=-1, precise=True)
+    out = mx.quantized_matmul(
+        scores, *q_values, transpose=False, group_size=group_size, bits=bits
+    )
+
+    if n_repeats > 1:
+        out = mx.reshape(out, (B, n_q_heads, L, D))
+
+    return out
+
+
 class QuantizedKVCache(_BaseCache):
-    def __init__(self, group_size: int = 64, bits: int = 4):
+    def __init__(self, group_size: int = 64, bits: int = 8):
         self.keys = None
         self.values = None
         self.offset = 0
@@ -154,71 +178,65 @@ class QuantizedKVCache(_BaseCache):
         B, n_kv_heads, num_steps, k_head_dim = keys.shape
         prev = self.offset
-        if self.keys is None or (prev + num_steps) > self.keys[0].shape[2]:
+        if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
             el_per_int = 8 * mx.uint32.size // self.bits
-            n_steps = (self.step + keys[0].shape[2] - 1) // self.step
+            new_steps = (self.step + num_steps - 1) // self.step * self.step
+            shape = (B, n_kv_heads, new_steps, k_head_dim // el_per_int)
+            group_shape = (B, n_kv_heads, new_steps, k_head_dim // self.group_size)

-            k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
-            scales_dim = k_head_dim // self.group_size
-            k_scale_shape = k_shape[:-1] + (scales_dim,)
-            v_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
+            def init_quant():
+                return (
+                    mx.zeros(shape, dtype=mx.uint32),
+                    mx.zeros(group_shape, dtype=keys.dtype),
+                    mx.zeros(group_shape, dtype=keys.dtype),
+                )

-            scale_bias_init = lambda: mx.zeros(k_scale_shape, keys.dtype)
-            new_k = (mx.zeros(k_shape, mx.uint32), scale_bias_init(), scale_bias_init())
-            new_v = (mx.zeros(v_shape, mx.uint32), scale_bias_init(), scale_bias_init())
+            def expand_quant(x):
+                new_x = mx.zeros((B, n_kv_heads, new_steps, x.shape[-1]), dtype=x.dtype)
+                return mx.concatenate([x, new_x], axis=-2)

             if self.keys is not None:
                 if prev % self.step != 0:
-                    self.keys = tuple(x[..., :prev, :] for x in self.keys)
-                    self.values = tuple(x[..., :prev, :] for x in self.values)
-                self.keys = tuple(
-                    mx.concatenate([self.keys[i], new_k[i]], axis=2) for i in range(3)
-                )
-                self.values = tuple(
-                    mx.concatenate([self.values[i], new_v[i]], axis=2) for i in range(3)
-                )
+                    self.keys, self.values = tree_map(
+                        lambda x: x[..., :prev, :], (self.keys, self.values)
+                    )
+
+                self.keys, self.values = tree_map(
+                    expand_quant, (self.keys, self.values)
+                )
             else:
-                self.keys, self.values = new_k, new_v
+                self.keys, self.values = init_quant(), init_quant()

         self.offset += num_steps

-        if num_steps > 1:
-            keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
-            values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
-            for i in range(len(self.keys)):
-                self.keys[i][..., prev : self.offset, :] = keys[i]
-                self.values[i][..., prev : self.offset, :] = values[i]
-        else:
-            outputs = mx.fast.quantized_kv_update(
-                keys,
-                values,
-                *self.keys,
-                *self.values,
-                prev,
-                group_size=self.group_size,
-                bits=self.bits,
-            )
-            self.keys = outputs[:3]
-            self.values = outputs[3:]
+        keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
+        values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
+        for i in range(len(self.keys)):
+            self.keys[i][..., prev : self.offset, :] = keys[i]
+            self.values[i][..., prev : self.offset, :] = values[i]

-        return (
-            tuple(x[..., : self.offset, :] for x in self.keys),
-            tuple(x[..., : self.offset, :] for x in self.values),
-        )
+        return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))

     @property
     def state(self):
-        return self.keys, self.values
+        if self.offset == self.keys[0].shape[2]:
+            return self.keys, self.values
+        else:
+            return tree_map(
+                lambda x: x[..., : self.offset, :], (self.keys, self.values)
+            )

-    @classmethod
-    def from_cache(
-        cls, cache: _BaseCache, group_size: int = 64, bits: int = 4
-    ) -> "QuantizedKVCache":
-        quant_cache = cls(group_size=group_size, bits=bits)
-        quant_cache.offset = cache.offset
-        quant_cache.keys = mx.quantize(cache.keys, group_size=group_size, bits=bits)
-        quant_cache.values = mx.quantize(cache.values, group_size=group_size, bits=bits)
-        return quant_cache
+    @state.setter
+    def state(self, v):
+        self.keys, self.values = v
+
+    @property
+    def meta_state(self):
+        return tuple(map(str, (self.step, self.offset, self.group_size, self.bits)))
+
+    @meta_state.setter
+    def meta_state(self, v):
+        self.step, self.offset, self.group_size, self.bits = map(int, v)


 class KVCache(_BaseCache):
@@ -276,7 +294,11 @@ class KVCache(_BaseCache):
         return n

     def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
-        return QuantizedKVCache.from_cache(self, group_size=group_size, bits=bits)
+        quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
+        quant_cache.offset = self.offset
+        quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
+        quant_cache.values = mx.quantize(self.values, group_size=group_size, bits=bits)
+        return quant_cache


 class RotatingKVCache(_BaseCache):
@@ -418,6 +440,13 @@ class RotatingKVCache(_BaseCache):
         self._idx -= n
         return n

+    def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
+        quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
+        quant_cache.offset = self.offset
+        quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
+        quant_cache.values = mx.quantize(self.values, group_size=group_size, bits=bits)
+        return quant_cache
+

 class MambaCache(_BaseCache):
     def __init__(self):
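The new quantized_scaled_dot_product_attention above relies on the (packed, scales, biases) triples that mx.quantize produces, which mx.quantized_matmul can consume directly, so attention scores are computed without materializing a dequantized cache. A small self-contained illustration of that layout, with purely illustrative shapes:

import mlx.core as mx

B, n_heads, L, D = 1, 4, 32, 64
queries = mx.random.normal((B, n_heads, 1, D))
keys = mx.random.normal((B, n_heads, L, D))

# mx.quantize returns a (packed uint32, scales, biases) triple per tensor.
q_keys = mx.quantize(keys, group_size=64, bits=8)

# Attention scores against the quantized keys, no explicit dequantization.
scores = mx.quantized_matmul(
    queries * (D**-0.5), *q_keys, transpose=True, group_size=64, bits=8
)
print(scores.shape)  # (1, 4, 1, 32)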

View File

@@ -7,7 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn

 from .base import BaseModelArgs, create_attention_mask
-from .cache import QuantizedKVCache
+from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention


 @dataclass
@@ -192,10 +192,10 @@ class Attention(nn.Module):
             keys = self.rope(keys)

         if isinstance(cache, QuantizedKVCache):
-            output = mx.fast.quantized_scaled_dot_product_attention(
+            output = quantized_scaled_dot_product_attention(
                 queries,
-                *keys,
-                *values,
+                keys,
+                values,
                 scale=self.scale,
                 mask=mask,
                 group_size=cache.group_size,

View File

@@ -7,6 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn

 from .base import BaseModelArgs, create_attention_mask
+from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention


 @dataclass
@@ -89,9 +90,20 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)

-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
-        )
+        if isinstance(cache, QuantizedKVCache):
+            output = quantized_scaled_dot_product_attention(
+                queries,
+                keys,
+                values,
+                scale=self.scale,
+                mask=mask,
+                group_size=cache.group_size,
+                bits=cache.bits,
+            )
+        else:
+            output = mx.fast.scaled_dot_product_attention(
+                queries, keys, values, scale=self.scale, mask=mask
+            )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

View File

@@ -33,6 +33,9 @@ MODEL_REMAPPING = {
 MAX_FILE_SIZE_GB = 5

+DEFAULT_KV_GROUP_SIZE = 64
+DEFAULT_KV_BITS = 8
+

 class ModelNotFoundError(Exception):
     def __init__(self, message):
@@ -159,6 +162,27 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
     return logits


+def check_quantized_kv_args(quantized_kv_start, kv_group_size, kv_bits):
+    if not quantized_kv_start and (
+        kv_group_size != DEFAULT_KV_GROUP_SIZE or kv_bits != DEFAULT_KV_BITS
+    ):
+        raise ValueError(
+            "--kv-group-size and --kv-bits only apply when --quantized-kv-start is specified."
+        )
+
+
+def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
+    if (
+        quantized_kv_start
+        and prompt_cache[0].offset > quantized_kv_start
+        and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
+    ):
+        return [
+            c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in prompt_cache
+        ]
+    return prompt_cache
+
+
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
@@ -173,7 +197,7 @@ def generate_step(
     prompt_cache: Optional[Any] = None,
     logit_bias: Optional[Dict[int, float]] = None,
     logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
-    quantized_kv: bool = False,
+    quantized_kv_start: Optional[int] = None,
     kv_group_size: int = 64,
     kv_bits: int = 8,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
@@ -261,14 +285,13 @@ def generate_step(
         prompt_cache = cache.make_prompt_cache(
             model,
             max_kv_size=max_kv_size,
-            quantized_kv=quantized_kv,
-            kv_group_size=kv_group_size,
-            kv_bits=kv_bits,
         )
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")

     def _step(y):
+        nonlocal prompt_cache
         logits = model(y[None], cache=prompt_cache)
         logits = logits[:, -1, :]
@@ -279,6 +302,10 @@ def generate_step(
             for processor in logits_processor:
                 logits = processor(tokens, logits)

+        prompt_cache = maybe_quantize_kv_cache(
+            prompt_cache, quantized_kv_start, kv_group_size, kv_bits
+        )
+
         y, logprobs = sample(logits)
         return y, logprobs.squeeze(0)
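maybe_quantize_kv_cache converts every layer's cache via to_quantized() once the first layer's offset passes quantized_kv_start, and otherwise returns the list untouched. A hedged sketch of that behavior, assuming the mlx_lm package built from this commit; the shapes and thresholds are illustrative only:

import mlx.core as mx
from mlx_lm.models.cache import KVCache, QuantizedKVCache
from mlx_lm.utils import maybe_quantize_kv_cache

prompt_cache = [KVCache() for _ in range(2)]
for c in prompt_cache:
    # Illustrative shapes: batch 1, 8 KV heads, 128 positions, head_dim 64.
    keys = mx.random.normal((1, 8, 128, 64))
    values = mx.random.normal((1, 8, 128, 64))
    c.update_and_fetch(keys, values)

# Threshold not reached yet: the original caches are returned unchanged.
assert maybe_quantize_kv_cache(prompt_cache, 512, 64, 8) is prompt_cache

# Threshold passed: every layer is swapped for a QuantizedKVCache.
quantized = maybe_quantize_kv_cache(prompt_cache, 32, 64, 8)
assert all(isinstance(c, QuantizedKVCache) for c in quantized)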