diff --git a/llms/mlx_lm/models/gemma3.py b/llms/mlx_lm/models/gemma3_text.py similarity index 86% rename from llms/mlx_lm/models/gemma3.py rename to llms/mlx_lm/models/gemma3_text.py index 4e43540e..7e74e7b9 100644 --- a/llms/mlx_lm/models/gemma3.py +++ b/llms/mlx_lm/models/gemma3_text.py @@ -12,19 +12,19 @@ from .base import BaseModelArgs, create_attention_mask @dataclass class ModelArgs(BaseModelArgs): model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int = 8 + hidden_size: int =1152 + num_hidden_layers: int = 26 + intermediate_size: int = 6912 + num_attention_heads: int = 4 head_dim: int = 256 rms_norm_eps: float = 1.0e-6 - vocab_size: int = 262208 - num_key_value_heads: int = 4 + vocab_size: int = 262144 + num_key_value_heads: int = 1 rope_global_base_freq: float = 1_000_000.0 rope_local_base_freq: float = 10_000.0 rope_traditional: bool = False - query_pre_attn_scalar: float = 0.0625 - sliding_window: int = 1024 + query_pre_attn_scalar: float = 256 + sliding_window: int = 512 rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None mm_tokens_per_image: int = 256 sliding_window_pattern: int = 6 @@ -60,7 +60,7 @@ class Attention(nn.Module): self.q_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps) self.k_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps) - self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern != 0 + self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern == 0 self.rope = nn.RoPE( head_dim, @@ -118,7 +118,7 @@ class MLP(nn.Module): def __call__(self, x) -> mx.array: # This should not be GELU approx, jax.nn.gelu - return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x)) + return self.down_proj(nn.gelu_fast_approx(self.gate_proj(x)) * self.up_proj(x)) class TransformerBlock(nn.Module): @@ -169,23 +169,20 @@ class Gemma3Model(nn.Module): def __call__( self, inputs: mx.array, - inputs_embeds: mx.array = None, mask: mx.array = None, cache=None, ): - if inputs_embeds is None: - h = self.embed_tokens(inputs) - else: - h = inputs_embeds + h = self.embed_tokens(inputs) h *= self.args.hidden_size**0.5 # persistent precision issue in scaling if cache is None: cache = [None] * len(self.layers) - # Sliding window - j = self.args.sliding_window_pattern - mask = create_attention_mask(h, cache[j - 1 : j]) + if mask is None: + # Sliding window + j = self.args.sliding_window_pattern + mask = create_attention_mask(h, cache[j - 1 : j]) for layer, c in zip(self.layers, cache): h = layer(h, mask, c) @@ -194,28 +191,27 @@ class Gemma3Model(nn.Module): class Model(nn.Module): - def __init__(self, config: ModelArgs): + def __init__(self, args: ModelArgs): super().__init__() - self.config = config - self.model_type = config.model_type - self.model = Gemma3Model(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.args = args + self.model_type = args.model_type + self.model = Gemma3Model(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) def __call__( self, inputs: mx.array, cache=None, - inputs_embeds=None, mask: Optional[mx.array] = None, ): - out = self.model(inputs, inputs_embeds, mask, cache) + out = self.model(inputs, mask, cache) out = self.lm_head(out) return out def sanitize(self, weights): if "lm_head.weight" not in weights: - weights["language_model.lm_head.weight"] = weights[ - "language_model.model.embed_tokens.weight" + weights["lm_head.weight"] = weights[ + "model.embed_tokens.weight" ] return { k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k @@ -245,6 +241,6 @@ class Model(nn.Module): ) else: caches.append( - RotatingKVCache() + RotatingKVCache(max_size=self.args.sliding_window, keep=0) ) return caches \ No newline at end of file