diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py
index 6b46ddc6..c6f3e885 100644
--- a/llms/mlx_lm/models/cohere2.py
+++ b/llms/mlx_lm/models/cohere2.py
@@ -85,20 +85,25 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        if cache is not None:
-            queries = self.rope(queries, offset=cache.offset)
-            keys = self.rope(keys, offset=cache.offset)
-            keys, values = cache.update_and_fetch(keys, values)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
+        # Apply RoPE only if sliding window is enabled
+        if self.sliding_window is not None:
+            if cache is None:
+                queries = self.rope(queries)
+                keys = self.rope(keys)
+            else:
+                queries = self.rope(queries, offset=cache.offset)
+                keys = self.rope(keys, offset=cache.offset)
 
-        # sliding window attention
-        if self.sliding_window is not None:
-            keys = keys[:, :, -self.sliding_window :, :]
-            values = values[:, :, -self.sliding_window :, :]
-            if mask is not None:
-                mask = mask[:, -self.sliding_window :]
+        if cache is not None:
+            keys, values = cache.update_and_fetch(keys, values)
+
+        # Apply sliding window attention if enabled
+        if self.sliding_window is not None:
+            window_size = self.sliding_window
+            keys = keys[..., -window_size:, :]
+            values = values[..., -window_size:, :]
+            if mask is not None:
+                mask = mask[..., -window_size:]
 
         output = mx.fast.scaled_dot_product_attention(
             queries, keys, values, scale=self.scale, mask=mask
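
For reference, a minimal standalone sketch of the reordered flow introduced above: RoPE is applied only on sliding-window layers, the cache is updated before the window is applied, and keys/values/mask are then truncated to the window. The `attend` function and `SimpleCache` class below are illustrative stand-ins (the toy cache only mimics the assumed `offset` / `update_and_fetch` interface), not the mlx_lm implementation.

```python
# Sketch only: mirrors the control flow of the diff under simplified assumptions.
import mlx.core as mx
import mlx.nn as nn


class SimpleCache:
    """Toy KV cache: concatenates new keys/values along the sequence axis."""

    def __init__(self):
        self.keys = None
        self.values = None
        self.offset = 0

    def update_and_fetch(self, keys, values):
        if self.keys is None:
            self.keys, self.values = keys, values
        else:
            self.keys = mx.concatenate([self.keys, keys], axis=2)
            self.values = mx.concatenate([self.values, values], axis=2)
        self.offset += keys.shape[2]
        return self.keys, self.values


def attend(queries, keys, values, rope, cache=None, sliding_window=None,
           mask=None, scale=1.0):
    # RoPE is applied only on sliding-window layers, matching the diff.
    if sliding_window is not None:
        offset = cache.offset if cache is not None else 0
        queries = rope(queries, offset=offset)
        keys = rope(keys, offset=offset)

    # The cache is updated before the window is applied, so the full history
    # is stored and only the attended slice is truncated below.
    if cache is not None:
        keys, values = cache.update_and_fetch(keys, values)

    # Keep only the last `sliding_window` key/value positions (and mask columns).
    if sliding_window is not None:
        keys = keys[..., -sliding_window:, :]
        values = values[..., -sliding_window:, :]
        if mask is not None:
            mask = mask[..., -sliding_window:]

    return mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=scale, mask=mask
    )


# Tiny usage example with made-up shapes.
B, H, L, D = 1, 4, 6, 32
rope = nn.RoPE(D)
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
out = attend(q, k, v, rope, cache=SimpleCache(), sliding_window=4, scale=D ** -0.5)
```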