Mirror of https://github.com/ml-explore/mlx-examples.git
Remove sliding window attention implementation, since it should be done using RotatingKVCache
parent ab960f80dd
commit 8924bdc546
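For context: with a rotating key/value cache, keys and values that fall outside the attention window are evicted from the cache during generation, so the attention layer no longer needs to build an explicit sliding-window mask. A minimal sketch of how such a cache might be set up, assuming mlx-lm exposes RotatingKVCache from mlx_lm.models.cache with max_size/keep keyword arguments (the exact import path and constructor signature may differ between versions):

# Sketch only: the module path and constructor arguments are assumptions
# and may differ between mlx-lm versions.
from mlx_lm.models.cache import RotatingKVCache

num_layers = 32        # hypothetical model depth
window_size = 4096     # hypothetical attention window

# One cache object per layer. Each RotatingKVCache keeps only the most
# recent `max_size` key/value entries (keep=0 pins no prefix tokens), so
# positions outside the window are dropped instead of being masked out.
cache = [RotatingKVCache(max_size=window_size, keep=0) for _ in range(num_layers)]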
@@ -354,15 +354,6 @@ class Mamba(nn.Module):
         return y


-def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
-    max_len = max(q_len, kv_len)
-    mask = mx.tril(
-        mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size),  # type: ignore
-        k=window_size,
-    )
-    return mask[-q_len:, -kv_len:]
-
-
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
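For reference, the removed helper produced a boolean band mask that is True only within window_size positions of the diagonal, sliced to the current query/key lengths. A small self-contained example of what it computed (assuming mlx is installed):

import mlx.core as mx

def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
    max_len = max(q_len, kv_len)
    # Band of width 2 * window_size + 1 around the diagonal; causality was
    # enforced separately by the additive mask coming from `generate`.
    mask = mx.tril(
        mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size),
        k=window_size,
    )
    return mask[-q_len:, -kv_len:]

# Decoding one token (q_len=1) against 5 cached keys with window_size=2:
# only the last 3 key positions are True.
print(swa_mask(1, 5, 2))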
@@ -422,23 +413,6 @@ class Attention(nn.Module):
         q = self.rope(q)
         k = self.rope(k)

-        if mask is not None:
-            if mask.dtype == bool:
-                mask = mx.where(mask, mx.array(0.0, dtype=mx.float16), float("-inf"))
-            if len(mask.shape) == 2:
-                mask = mask[None, None]
-            assert len(mask.shape) == 4
-
-            m_swa = swa_mask(
-                q.shape[2],
-                k.shape[2],
-                self.config.attention_window_size,
-            )
-            # `generate` function creates attention mask that does not consider sliding window
-            m_swa = m_swa[None, None]
-            mask = mask[:, :, -q.shape[2] :, -k.shape[2] :]
-            mask = mx.where(m_swa, mask, float("-inf"))
-
         output = mx.fast.scaled_dot_product_attention(
             q,
             k,
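The remaining path passes an ordinary additive mask straight to mx.fast.scaled_dot_product_attention. A minimal, self-contained illustration of that call, with made-up shapes:

import math
import mlx.core as mx

B, n_heads, L, head_dim = 1, 4, 8, 16
q = mx.random.normal((B, n_heads, L, head_dim))
k = mx.random.normal((B, n_heads, L, head_dim))
v = mx.random.normal((B, n_heads, L, head_dim))

# Additive mask: 0 where attention is allowed, -inf where it is not.
causal = mx.tril(mx.ones((L, L), dtype=mx.bool_))
mask = mx.where(causal, 0.0, float("-inf"))[None, None]  # broadcast over batch/heads

out = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=1.0 / math.sqrt(head_dim), mask=mask
)
print(out.shape)  # (1, 4, 8, 16)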