mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 09:56:24 +08:00
fix sliding window
This commit is contained in:
parent
0337646b4e
commit
d7d70487eb
@ -85,20 +85,25 @@ class Attention(nn.Module):
|
|||||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
if cache is not None:
|
# Apply RoPE only if sliding window is enabled
|
||||||
queries = self.rope(queries, offset=cache.offset)
|
if self.sliding_window is not None:
|
||||||
keys = self.rope(keys, offset=cache.offset)
|
if cache is None:
|
||||||
keys, values = cache.update_and_fetch(keys, values)
|
queries = self.rope(queries)
|
||||||
else:
|
keys = self.rope(keys)
|
||||||
queries = self.rope(queries)
|
else:
|
||||||
keys = self.rope(keys)
|
queries = self.rope(queries, offset=cache.offset)
|
||||||
|
keys = self.rope(keys, offset=cache.offset)
|
||||||
|
|
||||||
# sliding window attention
|
if cache is not None:
|
||||||
if self.sliding_window is not None:
|
keys, values = cache.update_and_fetch(keys, values)
|
||||||
keys = keys[:, :, -self.sliding_window :, :]
|
|
||||||
values = values[:, :, -self.sliding_window :, :]
|
# Apply sliding window attention if enabled
|
||||||
if mask is not None:
|
if self.sliding_window is not None:
|
||||||
mask = mask[:, -self.sliding_window :]
|
window_size = self.sliding_window
|
||||||
|
keys = keys[..., -window_size:, :]
|
||||||
|
values = values[..., -window_size:, :]
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask[..., -window_size:]
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = mx.fast.scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, scale=self.scale, mask=mask
|
||||||
|
Loading…
Reference in New Issue
Block a user