mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
simplify
This commit is contained in:
parent 48655a7f83
commit 37a3723823
@@ -8,7 +8,13 @@ import time
import mlx.core as mx

from .models.cache import make_prompt_cache, save_prompt_cache
from .utils import load
from .utils import (
    DEFAULT_KV_BITS,
    DEFAULT_KV_GROUP_SIZE,
    check_quantized_kv_args,
    load,
    maybe_quantize_kv_cache,
)


def setup_arg_parser():

@@ -70,6 +76,24 @@ def setup_arg_parser():
        required=True,
        help="Message to be processed by the model ('-' reads from stdin)",
    )
    parser.add_argument(
        "--quantized-kv-start",
        help="Use a quantized KV cache from this step onwards.",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--kv-group-size",
        type=int,
        help="Group size for kv cache quantization.",
        default=DEFAULT_KV_GROUP_SIZE,
    )
    parser.add_argument(
        "--kv-bits",
        type=int,
        help="Number of bits for kv cache quantization.",
        default=DEFAULT_KV_BITS,
    )
    return parser

@@ -93,6 +117,8 @@ def main():
    args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt

    check_quantized_kv_args(args.quantized_kv_start, args.kv_group_size, args.kv_bits)

    if args.use_default_chat_template:
        if tokenizer.chat_template is None:
            tokenizer.chat_template = tokenizer.default_chat_template

@@ -127,6 +153,7 @@ def main():
    start = time.time()
    max_msg_len = 0
    while y.size > 0:
        model(y[:step_size][None], cache=cache)
        mx.eval([c.state for c in cache])
        processed += min(y.size, step_size)

@@ -136,6 +163,11 @@ def main():
        msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
        max_msg_len = max(max_msg_len, len(msg))
        print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)

        cache = maybe_quantize_kv_cache(
            cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
        )

    print()
    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
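For orientation only (not part of the commit): the three flags added above parse as plain argparse integers, with defaults taken from the constants imported from .utils. A minimal, self-contained sketch with illustrative values:

# Illustrative sketch of the new CLI surface; values are made up, defaults mirror the diff.
import argparse

DEFAULT_KV_GROUP_SIZE = 64
DEFAULT_KV_BITS = 8

parser = argparse.ArgumentParser()
parser.add_argument("--quantized-kv-start", type=int, default=None,
                    help="Use a quantized KV cache from this step onwards.")
parser.add_argument("--kv-group-size", type=int, default=DEFAULT_KV_GROUP_SIZE,
                    help="Group size for kv cache quantization.")
parser.add_argument("--kv-bits", type=int, default=DEFAULT_KV_BITS,
                    help="Number of bits for kv cache quantization.")

args = parser.parse_args(["--quantized-kv-start", "1000", "--kv-bits", "4"])
print(args.quantized_kv_start, args.kv_group_size, args.kv_bits)  # 1000 64 4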
@@ -6,8 +6,8 @@ import sys

import mlx.core as mx

from .models.cache import load_prompt_cache
from .utils import generate, load
from .models.cache import QuantizedKVCache, load_prompt_cache
from .utils import check_quantized_kv_args, generate, load

DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100

@@ -108,20 +108,23 @@ def setup_arg_parser():
        help="A file containing saved KV caches to avoid recomputing them",
    )
    parser.add_argument(
        "--quantized-kv",
        help="Whether to quantize the KV cache.",
        action="store_true",
        "--quantized-kv-start",
        help="Use a quantized KV cache from this step onwards.",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--kv-group-size",
        type=int,
        help="Group size for kv cache quantization.",
        help="Group size for kv cache quantization. "
        "--quantized-kv-start must be provided to have an effect.",
        default=64,
    )
    parser.add_argument(
        "--kv-bits",
        type=int,
        help="Number of bits for kv cache quantization.",
        help="Number of bits for kv cache quantization. "
        "--quantized-kv-start must be provided to have an effect.",
        default=8,
    )
    return parser

@@ -169,10 +172,14 @@ def main():
        prompt_cache, metadata = load_prompt_cache(
            args.prompt_cache_file,
            return_metadata=True,
            quantized_kv=args.quantized_kv,
            kv_group_size=args.kv_group_size,
            kv_bits=args.kv_bits,
        )
        if args.quantized_kv_start and isinstance(prompt_cache[0], QuantizedKVCache):
            raise ValueError(
                "Specified `--quantized-kv-start` but cache from "
                "`--prompt-cache-file` is already quantized."
            )

    check_quantized_kv_args(args.quantized_kv_start, args.kv_group_size, args.kv_bits)

    # Building tokenizer_config
    tokenizer_config = (

@@ -248,7 +255,7 @@ def main():
        top_p=args.top_p,
        max_kv_size=args.max_kv_size,
        prompt_cache=prompt_cache if using_cache else None,
        quantized_kv=args.quantized_kv,
        quantized_kv_start=args.quantized_kv_start,
        kv_group_size=args.kv_group_size,
        kv_bits=args.kv_bits,
    )
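Not part of the diff, just a usage sketch: assuming the package's high-level generate() forwards these keyword arguments to generate_step (as the call in main() above suggests), generation with a delayed, quantized KV cache could look roughly like this; the model id is a placeholder.

# Hypothetical usage sketch; model id and keyword forwarding are assumptions.
from mlx_lm import load, generate

model, tokenizer = load("some/model")  # placeholder repo id
text = generate(
    model,
    tokenizer,
    prompt="hello",
    max_tokens=100,
    quantized_kv_start=1000,  # switch the KV cache to quantized after 1000 cached tokens
    kv_group_size=64,
    kv_bits=4,
)
print(text)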
@@ -4,15 +4,12 @@ from typing import Any, Dict, List, Optional

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten
from mlx.utils import tree_flatten, tree_map, tree_unflatten


def make_prompt_cache(
    model: nn.Module,
    max_kv_size: Optional[int] = None,
    quantized_kv: bool = False,
    kv_group_size: int = 64,
    kv_bits: int = 8,
) -> List[Any]:
    """
    Construct the model's cache for use when generating.

@@ -30,12 +27,7 @@ def make_prompt_cache(
        return model.make_cache()

    num_layers = len(model.layers)
    if quantized_kv:
        return [
            QuantizedKVCache(group_size=kv_group_size, bits=kv_bits)
            for _ in range(num_layers)
        ]
    elif max_kv_size is not None:
    if max_kv_size is not None:
        return [
            RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
        ]

@@ -62,9 +54,7 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str]
    mx.save_safetensors(file_name, cache_data, cache_metadata)


def load_prompt_cache(
    file_name, return_metadata=False, quantized_kv=False, kv_group_size=64, kv_bits=8
):
def load_prompt_cache(file_name, return_metadata=False):
    """
    Load a prompt cache from a file.

@@ -85,8 +75,6 @@ def load_prompt_cache(
    for c, state, meta_state in zip(cache, arrays, info):
        c.state = state
        c.meta_state = meta_state
    if quantized_kv:
        cache = [c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in cache]
    if return_metadata:
        return cache, metadata
    return cache
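With quantization removed from load_prompt_cache, a saved cache round-trips unchanged and conversion is an explicit, separate step. A rough sketch, assuming the helpers are importable from the package (paths and model id are illustrative):

# Sketch: save a prompt cache, reload it, then quantize it explicitly.
from mlx_lm import load
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache

model, tokenizer = load("some/model")  # placeholder
cache = make_prompt_cache(model)       # plain per-layer caches, no quantization args anymore
# ... run the prompt through the model with `cache` here ...
save_prompt_cache("prompt.safetensors", cache, {"note": "example"})

cache, metadata = load_prompt_cache("prompt.safetensors", return_metadata=True)
cache = [c.to_quantized(group_size=64, bits=8) for c in cache]  # explicit conversion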
@@ -141,8 +129,44 @@ class _BaseCache:
        return False


def quantized_scaled_dot_product_attention(
    queries: mx.array,
    q_keys: tuple[mx.array, mx.array, mx.array],
    q_values: tuple[mx.array, mx.array, mx.array],
    scale: float,
    mask: Optional[mx.array],
    group_size: int = 64,
    bits: int = 8,
) -> mx.array:
    B, n_q_heads, L, D = queries.shape
    n_kv_heads = q_keys[0].shape[-3]
    n_repeats = n_q_heads // n_kv_heads

    queries *= scale

    if n_repeats > 1:
        queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
        q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
        q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)

    scores = mx.quantized_matmul(
        queries, *q_keys, transpose=True, group_size=group_size, bits=bits
    )
    if mask is not None:
        scores += mask
    scores = mx.softmax(scores, axis=-1, precise=True)
    out = mx.quantized_matmul(
        scores, *q_values, transpose=False, group_size=group_size, bits=bits
    )

    if n_repeats > 1:
        out = mx.reshape(out, (B, n_q_heads, L, D))

    return out
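The helper above is meant to reproduce ordinary scaled-dot-product attention when the keys and values are stored quantized. A self-checking sketch under assumed import paths; shapes are illustrative and grouped-query attention is exercised by using fewer KV heads than query heads:

# Sketch: compare the quantized helper against the fp path on dequantized K/V.
import mlx.core as mx
from mlx_lm.models.cache import quantized_scaled_dot_product_attention  # path assumed

B, n_heads, n_kv_heads, L, S, D = 1, 8, 2, 4, 32, 64
q = mx.random.normal((B, n_heads, L, D))
k = mx.random.normal((B, n_kv_heads, S, D))
v = mx.random.normal((B, n_kv_heads, S, D))

q_keys = mx.quantize(k, group_size=64, bits=8)    # (packed, scales, biases)
q_values = mx.quantize(v, group_size=64, bits=8)

# Reference: dequantize and use the standard fast SDPA kernel.
k_fp = mx.dequantize(*q_keys, group_size=64, bits=8)
v_fp = mx.dequantize(*q_values, group_size=64, bits=8)
out_fp = mx.fast.scaled_dot_product_attention(q, k_fp, v_fp, scale=D**-0.5, mask=None)

out_q = quantized_scaled_dot_product_attention(
    q, q_keys, q_values, scale=D**-0.5, mask=None, group_size=64, bits=8
)
print(mx.abs(out_q - out_fp).max())  # expected to be small (numerical differences only)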
class QuantizedKVCache(_BaseCache):
    def __init__(self, group_size: int = 64, bits: int = 4):
    def __init__(self, group_size: int = 64, bits: int = 8):
        self.keys = None
        self.values = None
        self.offset = 0

@@ -154,71 +178,65 @@ class QuantizedKVCache(_BaseCache):
        B, n_kv_heads, num_steps, k_head_dim = keys.shape
        prev = self.offset

        if self.keys is None or (prev + num_steps) > self.keys[0].shape[2]:
        if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
            el_per_int = 8 * mx.uint32.size // self.bits
            n_steps = (self.step + keys[0].shape[2] - 1) // self.step
            new_steps = (self.step + num_steps - 1) // self.step * self.step
            shape = (B, n_kv_heads, new_steps, k_head_dim // el_per_int)
            group_shape = (B, n_kv_heads, new_steps, k_head_dim // self.group_size)

            k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)
            scales_dim = k_head_dim // self.group_size
            k_scale_shape = k_shape[:-1] + (scales_dim,)
            v_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim // el_per_int)

            def init_quant():
                return (
                    mx.zeros(shape, dtype=mx.uint32),
                    mx.zeros(group_shape, dtype=keys.dtype),
                    mx.zeros(group_shape, dtype=keys.dtype),
                )

            scale_bias_init = lambda: mx.zeros(k_scale_shape, keys.dtype)
            new_k = (mx.zeros(k_shape, mx.uint32), scale_bias_init(), scale_bias_init())
            new_v = (mx.zeros(v_shape, mx.uint32), scale_bias_init(), scale_bias_init())

            def expand_quant(x):
                new_x = mx.zeros((B, n_kv_heads, new_steps, x.shape[-1]), dtype=x.dtype)
                return mx.concatenate([x, new_x], axis=-2)

            if self.keys is not None:
                if prev % self.step != 0:
                    self.keys = tuple(x[..., :prev, :] for x in self.keys)
                    self.values = tuple(x[..., :prev, :] for x in self.values)
                self.keys = tuple(
                    mx.concatenate([self.keys[i], new_k[i]], axis=2) for i in range(3)
                )
                self.values = tuple(
                    mx.concatenate([self.values[i], new_v[i]], axis=2) for i in range(3)
                )
                    self.keys, self.values = tree_map(
                        lambda x: x[..., :prev, :], (self.keys, self.values)
                    )

                self.keys, self.values = tree_map(
                    expand_quant, (self.keys, self.values)
                )
            else:
                self.keys, self.values = new_k, new_v
                self.keys, self.values = init_quant(), init_quant()

        self.offset += num_steps

        if num_steps > 1:
            keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
            values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
            for i in range(len(self.keys)):
                self.keys[i][..., prev : self.offset, :] = keys[i]
                self.values[i][..., prev : self.offset, :] = values[i]
        keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
        values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
        for i in range(len(self.keys)):
            self.keys[i][..., prev : self.offset, :] = keys[i]
            self.values[i][..., prev : self.offset, :] = values[i]

        else:
            outputs = mx.fast.quantized_kv_update(
                keys,
                values,
                *self.keys,
                *self.values,
                prev,
                group_size=self.group_size,
                bits=self.bits
            )
            self.keys = outputs[:3]
            self.values = outputs[3:]

        return (
            tuple(x[..., : self.offset, :] for x in self.keys),
            tuple(x[..., : self.offset, :] for x in self.values),
        )
        return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))

    @property
    def state(self):
        return self.keys, self.values
        if self.offset == self.keys[0].shape[2]:
            return self.keys, self.values
        else:
            return tree_map(
                lambda x: x[..., : self.offset, :], (self.keys, self.values)
            )

    @classmethod
    def from_cache(
        cls, cache: _BaseCache, group_size: int = 64, bits: int = 4
    ) -> "QuantizedKVCache":
        quant_cache = cls(group_size=group_size, bits=bits)
        quant_cache.offset = cache.offset
        quant_cache.keys = mx.quantize(cache.keys, group_size=group_size, bits=bits)
        quant_cache.values = mx.quantize(cache.values, group_size=group_size, bits=bits)
        return quant_cache

    @state.setter
    def state(self, v):
        self.keys, self.values = v

    @property
    def meta_state(self):
        return tuple(map(str, (self.step, self.offset, self.group_size, self.bits)))

    @meta_state.setter
    def meta_state(self, v):
        self.step, self.offset, self.group_size, self.bits = map(int, v)
class KVCache(_BaseCache):

@@ -276,7 +294,11 @@ class KVCache(_BaseCache):
        return n

    def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
        return QuantizedKVCache.from_cache(self, group_size=group_size, bits=bits)
        quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
        quant_cache.offset = self.offset
        quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
        quant_cache.values = mx.quantize(self.values, group_size=group_size, bits=bits)
        return quant_cache


class RotatingKVCache(_BaseCache):

@@ -418,6 +440,13 @@ class RotatingKVCache(_BaseCache):
        self._idx -= n
        return n

    def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
        quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
        quant_cache.offset = self.offset
        quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
        quant_cache.values = mx.quantize(self.values, group_size=group_size, bits=bits)
        return quant_cache


class MambaCache(_BaseCache):
    def __init__(self):
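Both cache classes now expose the same conversion path in place of the old QuantizedKVCache.from_cache classmethod. A small sketch (import path assumed; the no-argument KVCache constructor matches the cache module this diff touches):

# Sketch: fill a regular KVCache, then convert it with to_quantized().
import mlx.core as mx
from mlx_lm.models.cache import KVCache  # path assumed

cache = KVCache()
k = mx.random.normal((1, 2, 128, 64))  # (batch, kv_heads, tokens, head_dim)
v = mx.random.normal((1, 2, 128, 64))
cache.update_and_fetch(k, v)

q_cache = cache.to_quantized(group_size=64, bits=4)
print(q_cache.offset, q_cache.bits)    # 128 4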
@@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask
from .cache import QuantizedKVCache
from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention


@dataclass

@@ -192,10 +192,10 @@ class Attention(nn.Module):
            keys = self.rope(keys)

        if isinstance(cache, QuantizedKVCache):
            output = mx.fast.quantized_scaled_dot_product_attention(
            output = quantized_scaled_dot_product_attention(
                queries,
                *keys,
                *values,
                keys,
                values,
                scale=self.scale,
                mask=mask,
                group_size=cache.group_size,
@@ -7,6 +7,7 @@ import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask
from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention


@dataclass

@@ -89,9 +90,20 @@ class Attention(nn.Module):
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        if isinstance(cache, QuantizedKVCache):
            output = quantized_scaled_dot_product_attention(
                queries,
                keys,
                values,
                scale=self.scale,
                mask=mask,
                group_size=cache.group_size,
                bits=cache.bits,
            )
        else:
            output = mx.fast.scaled_dot_product_attention(
                queries, keys, values, scale=self.scale, mask=mask
            )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
@@ -33,6 +33,9 @@ MODEL_REMAPPING = {

MAX_FILE_SIZE_GB = 5

DEFAULT_KV_GROUP_SIZE = 64
DEFAULT_KV_BITS = 8


class ModelNotFoundError(Exception):
    def __init__(self, message):

@@ -159,6 +162,27 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
    return logits


def check_quantized_kv_args(quantized_kv_start, kv_group_size, kv_bits):
    if not quantized_kv_start and (
        kv_group_size != DEFAULT_KV_GROUP_SIZE or kv_bits != DEFAULT_KV_BITS
    ):
        raise ValueError(
            "--kv-group-size and --kv-bits only apply when --quantized-kv-start is specified."
        )


def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
    if (
        quantized_kv_start
        and prompt_cache[0].offset > quantized_kv_start
        and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
    ):
        return [
            c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in prompt_cache
        ]
    return prompt_cache
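A quick sketch of the two helpers' contract, assuming they are importable from the package's utils module: the check rejects group-size/bits overrides when no start step is given, and the conversion only fires once the cache has grown past the threshold and is not already quantized.

# Sketch of the helper semantics defined above (import path assumed).
from mlx_lm.utils import check_quantized_kv_args, maybe_quantize_kv_cache

check_quantized_kv_args(1000, 64, 8)      # ok: start step given, defaults elsewhere
try:
    check_quantized_kv_args(None, 32, 4)  # overrides without --quantized-kv-start
except ValueError as e:
    print(e)

# No-op until prompt_cache[0].offset > quantized_kv_start, then converts each layer:
# prompt_cache = maybe_quantize_kv_cache(prompt_cache, 1000, 64, 8)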
def generate_step(
    prompt: mx.array,
    model: nn.Module,

@@ -173,7 +197,7 @@ def generate_step(
    prompt_cache: Optional[Any] = None,
    logit_bias: Optional[Dict[int, float]] = None,
    logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    quantized_kv: bool = False,
    quantized_kv_start: Optional[int] = None,
    kv_group_size: int = 64,
    kv_bits: int = 8,
) -> Generator[Tuple[mx.array, mx.array], None, None]:

@@ -261,14 +285,13 @@ def generate_step(
        prompt_cache = cache.make_prompt_cache(
            model,
            max_kv_size=max_kv_size,
            quantized_kv=quantized_kv,
            kv_group_size=kv_group_size,
            kv_bits=kv_bits,
        )
    elif len(prompt_cache) != len(model.layers):
        raise ValueError("Wrong number of layers in the prompt cache.")

    def _step(y):
        nonlocal prompt_cache
        logits = model(y[None], cache=prompt_cache)
        logits = logits[:, -1, :]

@@ -279,6 +302,10 @@ def generate_step(
            for processor in logits_processor:
                logits = processor(tokens, logits)

        prompt_cache = maybe_quantize_kv_cache(
            prompt_cache, quantized_kv_start, kv_group_size, kv_bits
        )

        y, logprobs = sample(logits)
        return y, logprobs.squeeze(0)
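At the lowest level, generate_step now accepts quantized_kv_start directly and converts its cache mid-stream via maybe_quantize_kv_cache. A hedged sketch of driving it by hand; the model id is a placeholder and the keyword arguments follow the hunks above:

# Sketch: token-by-token generation with delayed KV cache quantization.
import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step  # path assumed

model, tokenizer = load("some/model")   # placeholder repo id
prompt = mx.array(tokenizer.encode("hello"))

for (token, _logprobs), _ in zip(
    generate_step(prompt, model, quantized_kv_start=512, kv_group_size=64, kv_bits=4),
    range(100),  # cap at 100 new tokens
):
    print(tokenizer.decode([token.item()]), end="", flush=True)
print()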