diff --git a/llms/mlx_lm/models/gemma3_text.py b/llms/mlx_lm/models/gemma3_text.py
index b0f76e16..5d7e312d 100644
--- a/llms/mlx_lm/models/gemma3_text.py
+++ b/llms/mlx_lm/models/gemma3_text.py
@@ -100,7 +100,7 @@ class Attention(nn.Module):
         if self.is_sliding and mask is not None:
             key_len = keys.shape[-2]
             if mask.shape[-1] != key_len:
-                mask = mask[..., :key_len]
+                mask = mask[..., -key_len:]

         output = mx.fast.scaled_dot_product_attention(
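
Note (not part of the patch): a minimal sketch of why the trailing slice is the right one when the sliding-window cache has dropped older keys. The toy mask below simply labels each key column with its position index so the slice result is visible; the sizes and variable names are illustrative, not taken from the model code.

```python
import mlx.core as mx

# Toy "mask" whose entries are the key-position indices, so slicing shows
# which key columns survive. Suppose 8 positions have been seen but only
# the last 4 keys remain in the sliding-window cache (key_len = 4).
total_positions, key_len = 8, 4
mask = mx.arange(total_positions)[None, :]   # shape (1, 8): columns 0..7

print(mask[..., :key_len])    # [[0 1 2 3]] -> columns of evicted keys (old behavior)
print(mask[..., -key_len:])   # [[4 5 6 7]] -> columns of keys still cached (this fix)
```

Since the cache retains the most recent `key_len` keys, the mask columns that correspond to them are the last `key_len` columns, which is what `mask[..., -key_len:]` selects.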