This commit is contained in:
Alex Barron
2024-10-28 16:03:43 -07:00
parent 48655a7f83
commit 37a3723823
6 changed files with 197 additions and 90 deletions

View File

@@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask
from .cache import QuantizedKVCache
from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention
@dataclass
@@ -192,10 +192,10 @@ class Attention(nn.Module):
keys = self.rope(keys)
if isinstance(cache, QuantizedKVCache):
output = mx.fast.quantized_scaled_dot_product_attention(
output = quantized_scaled_dot_product_attention(
queries,
*keys,
*values,
keys,
values,
scale=self.scale,
mask=mask,
group_size=cache.group_size,