fix sliding window

This commit is contained in:
Prince Canuma 2024-12-14 17:06:57 +01:00
parent 0337646b4e
commit d7d70487eb

View File

@@ -85,20 +85,25 @@ class Attention(nn.Module):
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None: # Apply RoPE only if sliding window is enabled
queries = self.rope(queries, offset=cache.offset) if self.sliding_window is not None:
keys = self.rope(keys, offset=cache.offset) if cache is None:
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
else:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
# sliding window attention if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
# Apply sliding window attention if enabled
if self.sliding_window is not None: if self.sliding_window is not None:
keys = keys[:, :, -self.sliding_window :, :] window_size = self.sliding_window
values = values[:, :, -self.sliding_window :, :] keys = keys[..., -window_size:, :]
values = values[..., -window_size:, :]
if mask is not None: if mask is not None:
mask = mask[:, -self.sliding_window :] mask = mask[..., -window_size:]
output = mx.fast.scaled_dot_product_attention( output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, scale=self.scale, mask=mask