Quantized KV Cache (#1075)

* add QuantizedKVCache

* simplify

* add tests

* single sdpa function

* fix sed

* in place

* fix tests

* support different k and v head dims

Author: Alex Barron
Date: 2024-10-31 16:59:52 -07:00
Committed by: GitHub
Parent: 9f34fdbda4
Commit: 85ffd2c96a
32 changed files with 411 additions and 85 deletions
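For context, the core idea in this commit is to keep the KV cache in packed low-bit form. Below is a minimal sketch of such a cache, assuming `mx.quantize` accepts the batched 4D key/value tensors. The class name and the `update_and_fetch` interface mirror the usual mlx-lm cache API, but the buffer handling here is simplified (a real cache would pre-allocate and grow in steps) and the `group_size`/`bits` defaults are assumptions, not the commit's actual values.

```python
import mlx.core as mx


class QuantizedKVCache:
    """Illustrative sketch: keep K/V as (packed, scales, biases) tuples."""

    def __init__(self, group_size: int = 64, bits: int = 8):
        self.keys = None  # (packed weights, scales, biases) or None
        self.values = None
        self.offset = 0
        self.group_size = group_size
        self.bits = bits

    def update_and_fetch(self, keys: mx.array, values: mx.array):
        # Quantize the incoming K/V along the last (head) dimension.
        new_k = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
        new_v = mx.quantize(values, group_size=self.group_size, bits=self.bits)
        if self.keys is None:
            self.keys, self.values = new_k, new_v
        else:
            # Append along the sequence axis; the packed data, scales,
            # and biases all share that axis, so concatenate each part.
            self.keys = tuple(
                mx.concatenate([old, new], axis=-2)
                for old, new in zip(self.keys, new_k)
            )
            self.values = tuple(
                mx.concatenate([old, new], axis=-2)
                for old, new in zip(self.values, new_v)
            )
        self.offset += keys.shape[-2]
        return self.keys, self.values
```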


@@ -7,7 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention

 @dataclass
@@ -92,10 +92,11 @@ class Attention(nn.Module):
         keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
         values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
-        output = mx.fast.scaled_dot_product_attention(
+        output = scaled_dot_product_attention(
             queries,
             keys,
             values,
+            cache=cache,
             scale=self.scale,
             mask=attention_mask,
         )
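The diff above routes attention through a single `scaled_dot_product_attention` helper that also receives the cache, so the quantized path can be chosen in one place. Below is a hedged sketch of such a dispatcher, assuming the `QuantizedKVCache` sketch above, matching query/key head counts (grouped-query broadcasting omitted), and an additive mask; the actual helper in this commit may differ in signature and details.

```python
import mlx.core as mx


def scaled_dot_product_attention(queries, keys, values, cache, scale, mask):
    # With a QuantizedKVCache, `keys`/`values` are (packed, scales, biases)
    # tuples, so attend with quantized matmuls instead of dequantizing.
    if isinstance(cache, QuantizedKVCache):
        k_packed, k_scales, k_biases = keys
        v_packed, v_scales, v_biases = values
        # scores = (Q * scale) @ K^T, computed against the packed keys.
        scores = mx.quantized_matmul(
            queries * scale, k_packed, k_scales, k_biases,
            transpose=True, group_size=cache.group_size, bits=cache.bits,
        )
        if mask is not None:
            scores = scores + mask  # assumes an additive mask
        probs = mx.softmax(scores, axis=-1)
        # probs @ V, again against the packed values.
        return mx.quantized_matmul(
            probs, v_packed, v_scales, v_biases,
            transpose=False, group_size=cache.group_size, bits=cache.bits,
        )
    # Full-precision cache: use the fused kernel as before.
    return mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=scale, mask=mask
    )
```

With this shape, each model's attention changes only at the call site, as in the diff above, and the cache type alone selects the kernel.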