diff --git a/llms/mlx_lm/models/gemma3_text.py b/llms/mlx_lm/models/gemma3_text.py
index 291d12fe..be71f461 100644
--- a/llms/mlx_lm/models/gemma3_text.py
+++ b/llms/mlx_lm/models/gemma3_text.py
@@ -1,3 +1,5 @@
+# Copyright © 2025 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Any, Optional
 
@@ -27,7 +29,6 @@ class ModelArgs(BaseModelArgs):
     sliding_window_pattern: int = 6
 
 
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_idx: int):
         super().__init__()
@@ -85,10 +86,8 @@ class Attention(nn.Module):
             keys = self.rope(keys)
 
         # Sliding window
-        if self.is_sliding and mask is not None:
-            key_len = keys.shape[-2]
-            if mask.shape[-1] != key_len:
-                mask = mask[..., -key_len:]
+        if mask is not None and mask.shape[-1] != keys.shape[-2]:
+            mask = mask[..., -keys.shape[-2] :]
 
         output = mx.fast.scaled_dot_product_attention(
             queries, keys, values, scale=self.scale, mask=mask
@@ -169,16 +168,28 @@ class Gemma3Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)
-        h *= self.args.hidden_size**0.5
+        h *= mx.array(self.args.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
 
         if cache is None:
             cache = [None] * len(self.layers)
 
         if mask is None:
             j = self.args.sliding_window_pattern
-            mask = create_attention_mask(h, cache[j - 1 : j])
+            full_mask = create_attention_mask(h, cache[j - 1 : j])
+            sliding_window_mask = create_attention_mask(h, cache)
 
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c)
+        for i, (layer, c) in enumerate(zip(self.layers, cache)):
+            is_global = (
+                i % self.args.sliding_window_pattern
+                == self.args.sliding_window_pattern - 1
+            )
+
+            local_mask = mask
+            if mask is None and is_global:
+                local_mask = full_mask
+            elif mask is None:
+                local_mask = sliding_window_mask
+
+            h = layer(h, local_mask, c)
 
         return self.norm(h)
 
@@ -213,7 +224,6 @@ class Model(nn.Module):
     def layers(self):
         return self.model.layers
 
-
    def make_cache(self):
        caches = []
        for i in range(self.args.num_hidden_layers):
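
Background on the per-layer mask selection in the Gemma3Model hunk: Gemma 3 text models interleave sliding-window attention layers with a global-attention layer every `sliding_window_pattern` layers, so the patch builds a full causal mask and a sliding-window mask once and hands each layer the one matching its cache type. The dependency-free sketch below only illustrates that selection rule; the `pick_masks` helper and the layer count are made up for the example, not taken from the repo.

```python
# Illustrative sketch: layers where i % pattern == pattern - 1 are the
# global-attention layers and get the full causal mask; every other layer
# gets the sliding-window mask. Strings stand in for the mask arrays that
# create_attention_mask builds in the patch.

def pick_masks(num_layers: int, pattern: int = 6) -> list:
    choices = []
    for i in range(num_layers):
        is_global = i % pattern == pattern - 1
        choices.append("full" if is_global else "sliding_window")
    return choices


# With the default pattern of 6, layers 5, 11, 17, ... are global:
print(pick_masks(num_layers=12))
# ['sliding_window', 'sliding_window', 'sliding_window', 'sliding_window',
#  'sliding_window', 'full', 'sliding_window', 'sliding_window',
#  'sliding_window', 'sliding_window', 'sliding_window', 'full']
```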