Quantized KV Cache (#1075)

* add QuantizedKVCache

* simplify

* add tests

* single sdpa function

* fix sed

* in place

* fix tests

* support different k and v head dims
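
For context, the sketch below illustrates the general idea behind a quantized KV cache using the public mx.quantize / mx.dequantize APIs: keys and values are stored as packed low-bit groups plus per-group scales and biases, and recovered to full precision when attention needs them. The class name, its methods, and the dequantize-on-read strategy are illustrative assumptions, not the code added in this PR.

    # Minimal sketch of a quantized KV cache (illustrative, not the PR's code).
    # Assumes a recent MLX where mx.quantize accepts batched (ndim > 2) inputs
    # and head_dim is a multiple of group_size.
    import mlx.core as mx


    class SketchQuantizedKVCache:
        def __init__(self, group_size: int = 64, bits: int = 4):
            self.group_size = group_size
            self.bits = bits
            self.keys = None    # (packed, scales, biases) tuple or None
            self.values = None
            self.offset = 0

        def _quantize(self, x: mx.array):
            # Quantize along the last axis (head_dim) into low-bit groups.
            return mx.quantize(x, group_size=self.group_size, bits=self.bits)

        def _concat(self, old, new):
            # Concatenate packed data, scales, and biases along the sequence axis.
            return tuple(mx.concatenate([o, n], axis=2) for o, n in zip(old, new))

        def update_and_fetch(self, keys: mx.array, values: mx.array):
            # keys/values: (B, n_kv_heads, L, head_dim)
            qk, qv = self._quantize(keys), self._quantize(values)
            if self.keys is None:
                self.keys, self.values = qk, qv
            else:
                self.keys = self._concat(self.keys, qk)
                self.values = self._concat(self.values, qv)
            self.offset += keys.shape[2]
            return self.keys, self.values

        def dequantized(self, dtype=mx.float16):
            # Recover full-precision K/V, e.g. for a fallback attention path.
            k = mx.dequantize(*self.keys, group_size=self.group_size, bits=self.bits)
            v = mx.dequantize(*self.values, group_size=self.group_size, bits=self.bits)
            return k.astype(dtype), v.astype(dtype)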
Author: Alex Barron
Date: 2024-10-31 16:59:52 -07:00
Committed by: GitHub
Parent: 9f34fdbda4
Commit: 85ffd2c96a
32 changed files with 411 additions and 85 deletions


@@ -8,7 +8,7 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
-from .base import create_attention_mask
+from .base import create_attention_mask, scaled_dot_product_attention
 from .switch_layers import SwitchMLP
@@ -71,8 +71,13 @@ class RoPEAttention(nn.Module):
         # Finally perform the attention computation
         scale = math.sqrt(1 / queries.shape[-1])
-        output = mx.fast.scaled_dot_product_attention(
-            queries.astype(mx.float32), keys, values, scale=scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries.astype(mx.float32),
+            keys,
+            values,
+            cache=cache,
+            scale=scale,
+            mask=mask,
         ).astype(values.dtype)
         output = output.moveaxis(2, 1).reshape(B, L, -1)
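
The hunk above swaps the direct mx.fast.scaled_dot_product_attention call for a single wrapper imported from .base that also receives the cache. Below is a hedged sketch of one way such a dispatching function could behave; the function name, the tuple check, and the dequantize-then-fast-kernel path are assumptions for illustration, not the exact code in this commit (which may, for instance, operate on the packed arrays directly via mx.quantized_matmul).

    # Sketch of a single SDPA entry point that handles both regular and
    # quantized caches (illustrative, not the PR's implementation).
    import mlx.core as mx


    def sdpa_with_optional_quantized_cache(queries, keys, values, *, cache=None, scale, mask=None):
        # If the cache returned quantized K/V, they arrive as
        # (packed, scales, biases) tuples from mx.quantize.
        if cache is not None and isinstance(keys, tuple):
            keys = mx.dequantize(*keys, group_size=cache.group_size, bits=cache.bits)
            values = mx.dequantize(*values, group_size=cache.group_size, bits=cache.bits)
            keys = keys.astype(queries.dtype)
            values = values.astype(queries.dtype)
        return mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=scale, mask=mask
        )

With a wrapper of this shape, each model calls attention the same way regardless of the cache type, which is what keeps the per-model diffs as small as the one shown above.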