single sdpa function

Alex Barron
2024-10-31 12:02:34 -07:00
parent 29f21e7fe4
commit 2e0690374e
31 changed files with 174 additions and 191 deletions


@@ -6,8 +6,7 @@ from typing import Any, Dict, Optional, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, create_attention_mask
-from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 
 
 @dataclass
@@ -191,20 +190,9 @@ class Attention(nn.Module):
         queries = self.rope(queries)
         keys = self.rope(keys)
 
-        if isinstance(cache, QuantizedKVCache):
-            output = quantized_scaled_dot_product_attention(
-                queries,
-                keys,
-                values,
-                scale=self.scale,
-                mask=mask,
-                group_size=cache.group_size,
-                bits=cache.bits,
-            )
-        else:
-            output = mx.fast.scaled_dot_product_attention(
-                queries, keys, values, scale=self.scale, mask=mask
-            )
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+        )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
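
The body of the new `scaled_dot_product_attention` helper imported from `.base` is not shown in this hunk. Based on the two branches it replaces above, a minimal sketch of what it presumably does (the exact signature in `base.py` may differ):

```python
from typing import Optional

import mlx.core as mx

from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention


def scaled_dot_product_attention(
    queries: mx.array,
    keys: mx.array,
    values: mx.array,
    cache,
    scale: float,
    mask: Optional[mx.array],
) -> mx.array:
    # Dispatch on the cache type, mirroring the per-model branches removed
    # above: a quantized KV cache routes to the quantized kernel with its own
    # group_size/bits, everything else uses the fused mx.fast kernel.
    if isinstance(cache, QuantizedKVCache):
        return quantized_scaled_dot_product_attention(
            queries,
            keys,
            values,
            scale=scale,
            mask=mask,
            group_size=cache.group_size,
            bits=cache.bits,
        )
    return mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=scale, mask=mask
    )
```

With a helper along these lines, each model's attention layer reduces to the single call shown in the hunk above, and the quantized-cache special case lives in one place instead of being repeated across all 31 changed files.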