diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py
index 14983350..def84221 100644
--- a/llms/mlx_lm/models/plamo2.py
+++ b/llms/mlx_lm/models/plamo2.py
@@ -354,15 +354,6 @@ class Mamba(nn.Module):
         return y
 
 
-def swa_mask(q_len: int, kv_len: int, window_size: int) -> mx.array:
-    max_len = max(q_len, kv_len)
-    mask = mx.tril(
-        mx.triu(mx.ones((max_len, max_len), dtype=mx.bool_), k=-window_size),  # type: ignore
-        k=window_size,
-    )
-    return mask[-q_len:, -kv_len:]
-
-
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
@@ -422,23 +413,6 @@ class Attention(nn.Module):
         q = self.rope(q)
         k = self.rope(k)
 
-        if mask is not None:
-            if mask.dtype == bool:
-                mask = mx.where(mask, mx.array(0.0, dtype=mx.float16), float("-inf"))
-            if len(mask.shape) == 2:
-                mask = mask[None, None]
-            assert len(mask.shape) == 4
-
-            m_swa = swa_mask(
-                q.shape[2],
-                k.shape[2],
-                self.config.attention_window_size,
-            )
-            # `generate` function creates attention mask that does not consider sliding window
-            m_swa = m_swa[None, None]
-            mask = mask[:, :, -q.shape[2] :, -k.shape[2] :]
-            mask = mx.where(m_swa, mask, float("-inf"))
-
         output = mx.fast.scaled_dot_product_attention(
             q,
             k,