Quantized KV Cache (#1075)

* add QuantizedKVCache

* simplify

* add tests

* single sdpa function

* fix sed

* in place

* fix tests

* support different k and v head dims
Author:       Alex Barron
Date:         2024-10-31 16:59:52 -07:00
Committed by: GitHub
Parent:       9f34fdbda4
Commit:       85ffd2c96a

32 changed files with 411 additions and 85 deletions
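The cache implementation itself is not shown in this excerpt. As a rough illustration only (not the PR's actual QuantizedKVCache), a cache along these lines could pack incoming keys/values with mx.quantize and hand them back still packed, leaving it to the attention helper (see the sketch after the diff) to decide how to consume them. The class name, the update_and_fetch method, and the group_size/bits attributes are assumptions for the sketch.

# A minimal sketch, assuming mx.quantize packs along the last axis and that
# the cache API mirrors update_and_fetch; not the PR's implementation.
import mlx.core as mx


class QuantizedKVCacheSketch:
    def __init__(self, group_size: int = 64, bits: int = 8):
        self.group_size = group_size
        self.bits = bits
        self.keys = None    # (packed, scales, biases), each (B, H, S, ...)
        self.values = None
        self.offset = 0

    def update_and_fetch(self, keys, values):
        # head_dim must be a multiple of group_size (e.g. 128 with group_size=64)
        # for mx.quantize to pack the last axis.
        new_k = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
        new_v = mx.quantize(values, group_size=self.group_size, bits=self.bits)
        if self.keys is None:
            self.keys, self.values = new_k, new_v
        else:
            # Concatenate each component (packed data, scales, biases)
            # along the sequence axis.
            self.keys = tuple(
                mx.concatenate([old, new], axis=2)
                for old, new in zip(self.keys, new_k)
            )
            self.values = tuple(
                mx.concatenate([old, new], axis=2)
                for old, new in zip(self.values, new_v)
            )
        self.offset += keys.shape[2]
        return self.keys, self.values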


@@ -1,12 +1,12 @@
 # Copyright © 2023-2024 Apple Inc.
 
 from dataclasses import dataclass
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 
 
 @dataclass
@@ -190,9 +190,10 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
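The diff's "single sdpa function" routes every model's attention through one helper imported from .base, which now also receives the cache. The real helper is not shown here; as an illustrative sketch only, a dispatcher like the one below could detect packed (quantized) keys/values and dequantize them before falling through to mx.fast.scaled_dot_product_attention. The tuple check and the cache's group_size/bits attributes are assumptions tied to the cache sketch above; the actual code may instead use quantized matmuls against the packed data.

# A rough sketch of a unified attention entry point; assumptions, not the PR's code.
import mlx.core as mx


def scaled_dot_product_attention(queries, keys, values, cache=None, scale=1.0, mask=None):
    if isinstance(keys, tuple):
        # Quantized path: keys/values arrive as (packed, scales, biases).
        keys = mx.dequantize(*keys, group_size=cache.group_size, bits=cache.bits)
        values = mx.dequantize(*values, group_size=cache.group_size, bits=cache.bits)
    return mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=scale, mask=mask
    )

With a dispatcher of this shape, the Attention module change above only needs to forward its cache object, and the same call site serves both regular and quantized caches.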