mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
revert is sliding pattern
This commit is contained in:
parent
645b666890
commit
2d30f6787a
@ -60,7 +60,7 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
|
self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
|
||||||
self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
|
self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
|
||||||
self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern == 0
|
self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern != 0
|
||||||
|
|
||||||
self.rope = nn.RoPE(
|
self.rope = nn.RoPE(
|
||||||
head_dim,
|
head_dim,
|
||||||
@ -102,6 +102,7 @@ class Attention(nn.Module):
|
|||||||
if mask.shape[-1] != key_len:
|
if mask.shape[-1] != key_len:
|
||||||
mask = mask[..., :key_len]
|
mask = mask[..., :key_len]
|
||||||
|
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = mx.fast.scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user