mirror of https://github.com/ml-explore/mlx-examples.git
Quantized KV Cache (#1075)
* add QuantizedKVCache
* simplify
* add tests
* single sdpa function
* fix sed
* in place
* fix tests
* support different k and v head dims
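For orientation, here is a minimal usage sketch (my illustration, not part of the commit) of the new single entry point: the same call serves both cache types, and dispatch happens on the cache class. The import paths `mlx_lm.models.base` and `mlx_lm.models.cache`, and the no-argument constructors, are assumptions based on this repository's layout.

```python
# Hedged usage sketch: one attention call for both cache types.
# Import paths and constructors are assumptions about this repo's layout.
import mlx.core as mx
from mlx_lm.models.base import scaled_dot_product_attention
from mlx_lm.models.cache import KVCache, QuantizedKVCache

B, n_heads, L, D = 1, 4, 16, 64
queries = mx.random.normal((B, n_heads, L, D))
keys = mx.random.normal((B, n_heads, L, D))
values = mx.random.normal((B, n_heads, L, D))

for cache in (KVCache(), QuantizedKVCache()):
    # A quantized cache returns (weights, scales, biases) tuples; a regular
    # cache returns plain arrays. The dispatcher below handles both.
    k, v = cache.update_and_fetch(keys, values)
    out = scaled_dot_product_attention(
        queries, k, v, cache=cache, scale=D**-0.5, mask=None
    )
    print(type(cache).__name__, out.shape)
```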
@@ -5,6 +5,9 @@ from dataclasses import dataclass
 from typing import Any, Optional
 
 import mlx.core as mx
+from mlx.utils import tree_map
+
+from .cache import QuantizedKVCache
 
 
 @dataclass
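The two new imports serve the quantized path: `tree_map` applies an operation across each element of the `(weights, scales, biases)` tuples that `QuantizedKVCache` stores, and the cache class itself is needed for dispatch below. A small sketch (my illustration, not from the commit) of what those tuples look like, assuming `mx.quantize` accepts batched arrays as the cache relies on:

```python
# Sketch: mx.quantize packs a tensor along its last axis into
# (packed_weights, scales, biases); these are the q_keys/q_values tuples.
import mlx.core as mx

k = mx.random.normal((1, 4, 16, 64))  # (B, n_kv_heads, S, D)
w_q, scales, biases = mx.quantize(k, group_size=64, bits=8)
# 64 values * 8 bits pack into 16 uint32 words per row; one scale and one
# bias per group of 64 elements.
print(w_q.shape, scales.shape, biases.shape)
```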
@@ -48,3 +51,63 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
     else:
         mask = None
     return mask
+
+
+def quantized_scaled_dot_product_attention(
+    queries: mx.array,
+    q_keys: tuple[mx.array, mx.array, mx.array],
+    q_values: tuple[mx.array, mx.array, mx.array],
+    scale: float,
+    mask: Optional[mx.array],
+    group_size: int = 64,
+    bits: int = 8,
+) -> mx.array:
+    B, n_q_heads, L, D = queries.shape
+    n_kv_heads = q_keys[0].shape[-3]
+    n_repeats = n_q_heads // n_kv_heads
+
+    queries *= scale
+
+    if n_repeats > 1:
+        queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
+        q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
+        q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
+
+    scores = mx.quantized_matmul(
+        queries, *q_keys, transpose=True, group_size=group_size, bits=bits
+    )
+    if mask is not None:
+        scores += mask
+    scores = mx.softmax(scores, axis=-1, precise=True)
+    out = mx.quantized_matmul(
+        scores, *q_values, transpose=False, group_size=group_size, bits=bits
+    )
+
+    if n_repeats > 1:
+        out = mx.reshape(out, (B, n_q_heads, L, D))
+
+    return out
+
+
+def scaled_dot_product_attention(
+    queries,
+    keys,
+    values,
+    cache,
+    scale: float,
+    mask: Optional[mx.array],
+) -> mx.array:
+    if isinstance(cache, QuantizedKVCache):
+        return quantized_scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            scale=scale,
+            mask=mask,
+            group_size=cache.group_size,
+            bits=cache.bits,
+        )
+    else:
+        return mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=scale, mask=mask
+        )
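As a sanity check on the math (again a hedged sketch, not part of the commit): dequantizing K and V and running the fused kernel should agree with `quantized_scaled_dot_product_attention` up to quantization error, since `mx.quantized_matmul` computes the same products against the packed representation. The import path is an assumption as above.

```python
# Hedged equivalence sketch: quantized attention vs. dequantize-then-fused.
import mlx.core as mx
from mlx_lm.models.base import quantized_scaled_dot_product_attention

B, n_heads, L, D = 1, 8, 32, 64
group_size, bits = 64, 8
scale = D**-0.5

queries = mx.random.normal((B, n_heads, L, D))
keys = mx.random.normal((B, n_heads, L, D))
values = mx.random.normal((B, n_heads, L, D))

q_keys = mx.quantize(keys, group_size=group_size, bits=bits)
q_values = mx.quantize(values, group_size=group_size, bits=bits)

out_q = quantized_scaled_dot_product_attention(
    queries, q_keys, q_values, scale=scale, mask=None,
    group_size=group_size, bits=bits,
)

# Reference path: dequantize, then use the regular fused kernel.
k = mx.dequantize(*q_keys, group_size=group_size, bits=bits)
v = mx.dequantize(*q_values, group_size=group_size, bits=bits)
out_ref = mx.fast.scaled_dot_product_attention(
    queries, k, v, scale=scale, mask=None
)
print(mx.abs(out_q - out_ref).max())  # small; bounded by quantization error
```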