diff --git a/llms/mlx_lm/models/gemma3_text.py b/llms/mlx_lm/models/gemma3_text.py
index 5d7e312d..68347922 100644
--- a/llms/mlx_lm/models/gemma3_text.py
+++ b/llms/mlx_lm/models/gemma3_text.py
@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Optional, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .cache import KVCache, RotatingKVCache
 from .base import BaseModelArgs, create_attention_mask
+from .cache import KVCache, RotatingKVCache
 
 
 @dataclass
 class ModelArgs(BaseModelArgs):
     model_type: str
-    hidden_size: int =1152
+    hidden_size: int = 1152
     num_hidden_layers: int = 26
     intermediate_size: int = 6912
     num_attention_heads: int = 4
@@ -102,13 +102,13 @@ class Attention(nn.Module):
         if mask.shape[-1] != key_len:
             mask = mask[..., -key_len:]
 
-
         output = mx.fast.scaled_dot_product_attention(
             queries, keys, values, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
 
+
 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
@@ -118,6 +118,7 @@ class RMSNorm(nn.Module):
     def __call__(self, x):
         return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
 
+
 class MLP(nn.Module):
     def __init__(self, dim, hidden_dim):
         super().__init__()
@@ -138,9 +139,7 @@ class TransformerBlock(nn.Module):
         self.self_attn = Attention(args, layer_idx)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
         self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(
-            args.hidden_size, eps=args.rms_norm_eps
-        )
+        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.pre_feedforward_layernorm = RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )
@@ -189,7 +188,6 @@ class Gemma3Model(nn.Module):
             cache = [None] * len(self.layers)
 
         if mask is None:
-            # Sliding window
             j = self.args.sliding_window_pattern
             mask = create_attention_mask(h, cache[j - 1 : j])
 
@@ -219,9 +217,7 @@ class Model(nn.Module):
 
     def sanitize(self, weights):
         if "lm_head.weight" not in weights:
-            weights["lm_head.weight"] = weights[
-                "model.embed_tokens.weight"
-            ]
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
@@ -245,11 +241,9 @@ class Model(nn.Module):
                 i % self.args.sliding_window_pattern
                 == self.args.sliding_window_pattern - 1
             ):
-                caches.append(
-                    KVCache()
-                )
+                caches.append(KVCache())
             else:
                 caches.append(
                     RotatingKVCache(max_size=self.args.sliding_window, keep=0)
                 )
-        return caches
\ No newline at end of file
+        return caches