Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-28 15:50:57 +08:00
Remove the sliding window attention implementation, since it should be done with RotatingKVCache
This commit is contained in:
parent ab960f80dd
commit 8924bdc546
@@ -354,15 +354,6 @@ class Mamba(nn.Module):
         return y
 
 
-def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
-    max_len = max(q_len, kv_len)
-    mask = mx.tril(
-        mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size),  # type: ignore
-        k=window_size,
-    )
-    return mask[-q_len:, -kv_len:]
-
-
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
@@ -422,23 +413,6 @@ class Attention(nn.Module):
         q = self.rope(q)
         k = self.rope(k)
 
-        if mask is not None:
-            if mask.dtype == bool:
-                mask = mx.where(mask, mx.array(0.0, dtype=mx.float16), float("-inf"))
-            if len(mask.shape) == 2:
-                mask = mask[None, None]
-            assert len(mask.shape) == 4
-
-            m_swa = swa_mask(
-                q.shape[2],
-                k.shape[2],
-                self.config.attention_window_size,
-            )
-            # `generate` function creates attention mask that does not consider sliding window
-            m_swa = m_swa[None, None]
-            mask = mask[:, :, -q.shape[2] :, -k.shape[2] :]
-            mask = mx.where(m_swa, mask, float("-inf"))
-
         output = mx.fast.scaled_dot_product_attention(
             q,
             k,
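The rationale behind the removal: a KV cache capped at the attention window keeps at most window_size keys and values per layer, so the banded swa_mask built above becomes redundant. Below is a minimal sketch of how a model can delegate the window to the cache instead, assuming the RotatingKVCache class from mlx_lm.models.cache with max_size/keep keyword arguments; the Model stand-in and the attention_window_size / num_hidden_layers config fields are illustrative placeholders, not taken from this commit.

# Sketch only: let the KV cache enforce the sliding window instead of a mask.
# Assumes mlx_lm.models.cache.RotatingKVCache(max_size=..., keep=...); the
# config fields below are illustrative placeholders.
from mlx_lm.models.cache import RotatingKVCache


class Model:  # stand-in for the model class touched by this diff
    def __init__(self, config):
        self.config = config

    def make_cache(self):
        # One cache per layer; entries older than the window are evicted by
        # the cache itself, so Attention no longer needs swa_mask.
        return [
            RotatingKVCache(max_size=self.config.attention_window_size, keep=0)
            for _ in range(self.config.num_hidden_layers)
        ]

With the cache bounded this way, the number of cached keys and values stays within the window during generation, so the ordinary causal mask produced by the generation code is sufficient.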