import inspect
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask
from .cache import KVCache, RotatingKVCache


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int = 1152
    num_hidden_layers: int = 26
    intermediate_size: int = 6912
    num_attention_heads: int = 4
    head_dim: int = 256
    rms_norm_eps: float = 1.0e-6
    vocab_size: int = 262144
    num_key_value_heads: int = 1
    rope_global_base_freq: float = 1_000_000.0
    rope_local_base_freq: float = 10_000.0
    rope_traditional: bool = False
    query_pre_attn_scalar: float = 256
    sliding_window: int = 512
    rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
    mm_tokens_per_image: int = 256
    sliding_window_pattern: int = 6

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )


class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.repeats = n_heads // n_kv_heads
        self.head_dim = head_dim = args.head_dim
        self.layer_idx = layer_idx

        self.scale = args.query_pre_attn_scalar**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        self.q_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
        self.k_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)

        # Every `sliding_window_pattern`-th layer attends globally; all other
        # layers use local sliding-window attention (see also `make_cache`).
        self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern != 0

        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=(
                args.rope_local_base_freq
                if self.is_sliding
                else args.rope_global_base_freq
            ),
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        queries = self.q_norm(queries)
        keys = self.k_norm(keys)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Sliding-window layers keep only the most recent keys, so trim the
        # mask to the last `key_len` positions to match the rotating cache.
        if self.is_sliding and mask is not None:
            key_len = keys.shape[-2]
            if mask.shape[-1] != key_len:
                mask = mask[..., -key_len:]

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        # Gemma 3 uses the tanh GELU approximation (`gelu_pytorch_tanh`),
        # which corresponds to `nn.gelu_approx` in MLX, not the exact GELU.
        return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args, layer_idx)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.pre_feedforward_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.post_feedforward_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + self.post_attention_layernorm(r)
        r = self.mlp(self.pre_feedforward_layernorm(h))
        out = h + self.post_feedforward_layernorm(r)
        return out


class Gemma3Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args, layer_idx=layer_idx)
            for layer_idx in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)
        # Scale embeddings by sqrt(hidden_size); the scale is computed in
        # bfloat16 to match the numerics of the reference implementation.
        h *= mx.array(self.args.hidden_size**0.5, mx.bfloat16).astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        if mask is None:
            # Build the causal mask from a global (non-rotating) cache entry;
            # sliding-window layers trim it to their own key length.
            j = self.args.sliding_window_pattern
            mask = create_attention_mask(h, cache[j - 1 : j])

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = Gemma3Model(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        mask: Optional[mx.array] = None,
    ):
        out = self.model(inputs, mask, cache)
        out = self.lm_head(out)
        return out

    def sanitize(self, weights):
        # The LM head is tied to the input embeddings when the checkpoint
        # does not ship a separate `lm_head.weight`.
        if "lm_head.weight" not in weights:
            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
        return {
            k: v
            for k, v in weights.items()
            if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.head_dim

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads

    def make_cache(self):
        caches = []
        for i in range(self.args.num_hidden_layers):
            if (
                i % self.args.sliding_window_pattern
                == self.args.sliding_window_pattern - 1
            ):
                # Global-attention layers keep the full context.
                caches.append(KVCache())
            else:
                # Local-attention layers only need the sliding window.
                caches.append(
                    RotatingKVCache(max_size=self.args.sliding_window, keep=0)
                )
        return caches
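
# --- Usage sketch -----------------------------------------------------------
# A minimal, hypothetical example of driving this module from a separate
# script (the relative imports above mean this file cannot be run directly).
# With no pretrained weights loaded it only demonstrates shapes and cache
# handling: a prefill pass over placeholder token ids, then a few greedy
# decode steps reusing the per-layer caches returned by `make_cache()`.
#
#     import mlx.core as mx
#     # `Model` and `ModelArgs` imported from this module
#
#     model = Model(ModelArgs(model_type="gemma3_text"))
#     cache = model.make_cache()
#
#     prompt = mx.array([[2, 106, 1645, 108]])  # placeholder token ids
#     logits = model(prompt, cache=cache)       # prefill: (1, L, vocab_size)
#     token = mx.argmax(logits[:, -1, :], axis=-1)
#
#     for _ in range(8):                        # greedy decoding, one token/step
#         logits = model(token[None], cache=cache)
#         token = mx.argmax(logits[:, -1, :], axis=-1)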