Revert `is_sliding` pattern

This commit is contained in:
Prince Canuma 2025-03-12 09:48:14 +01:00
parent 645b666890
commit 2d30f6787a

View File

@ -60,7 +60,7 @@ class Attention(nn.Module):
self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps) self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps) self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern == 0 self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern != 0
self.rope = nn.RoPE( self.rope = nn.RoPE(
head_dim, head_dim,
@ -102,6 +102,7 @@ class Attention(nn.Module):
if mask.shape[-1] != key_len: if mask.shape[-1] != key_len:
mask = mask[..., :key_len] mask = mask[..., :key_len]
output = mx.fast.scaled_dot_product_attention( output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, scale=self.scale, mask=mask
) )