Remove the sliding window attention implementation, since it should be handled by RotatingKVCache

Shunta Saito 2025-02-28 03:55:52 +09:00
parent ab960f80dd
commit 8924bdc546


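For reference, a minimal sketch of the approach the commit message points to, assuming the current mlx-lm cache API (mlx_lm.models.cache): the model's make_cache returns a RotatingKVCache bounded by config.attention_window_size for attention layers, so the sliding window is enforced by the cache itself and no explicit window mask is needed in Attention. The per-layer is_attention flag and the self.layers attribute below are illustrative assumptions, not taken from this diff.

# Sketch only, assuming the mlx-lm cache API; the is_attention flag and
# self.layers attribute are hypothetical, used here just for illustration.
from mlx_lm.models.cache import MambaCache, RotatingKVCache

def make_cache(self):
    # Attention layers get a rotating cache whose size equals the sliding
    # window, so keys/values older than the window are evicted by the cache
    # instead of being masked out. Mamba layers keep their own state cache.
    return [
        (
            RotatingKVCache(max_size=self.config.attention_window_size, keep=0)
            if layer.is_attention
            else MambaCache()
        )
        for layer in self.layers
    ]

With a cache like this, the window masking removed below should no longer be necessary; mlx-lm's standard mask creation accounts for a rotating cache's window when one is present.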
@@ -354,15 +354,6 @@ class Mamba(nn.Module):
        return y
-def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
-    max_len = max(q_len, kv_len)
-    mask = mx.tril(
-        mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size),  # type: ignore
-        k=window_size,
-    )
-    return mask[-q_len:, -kv_len:]
class Attention(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
@@ -422,23 +413,6 @@ class Attention(nn.Module):
        q = self.rope(q)
        k = self.rope(k)
-        if mask is not None:
-            if mask.dtype == bool:
-                mask = mx.where(mask, mx.array(0.0, dtype=mx.float16), float("-inf"))
-            if len(mask.shape) == 2:
-                mask = mask[None, None]
-            assert len(mask.shape) == 4
-            m_swa = swa_mask(
-                q.shape[2],
-                k.shape[2],
-                self.config.attention_window_size,
-            )
-            # `generate` function creates attention mask that does not consider sliding window
-            m_swa = m_swa[None, None]
-            mask = mask[:, :, -q.shape[2] :, -k.shape[2] :]
-            mask = mx.where(m_swa, mask, float("-inf"))
        output = mx.fast.scaled_dot_product_attention(
            q,
            k,