From 645b6668908845ec7970ec0d71cf78fb5a8c5885 Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Wed, 12 Mar 2025 09:41:42 +0100
Subject: [PATCH] revert rmsnorm

---
 llms/mlx_lm/models/gemma3_text.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/llms/mlx_lm/models/gemma3_text.py b/llms/mlx_lm/models/gemma3_text.py
index 7e74e7b9..8843dc47 100644
--- a/llms/mlx_lm/models/gemma3_text.py
+++ b/llms/mlx_lm/models/gemma3_text.py
@@ -58,8 +58,8 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
-        self.q_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
-        self.k_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
+        self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
+        self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
 
         self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern == 0
 
         self.rope = nn.RoPE(
@@ -108,6 +108,14 @@ class Attention(nn.Module):
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
 
+class RMSNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5):
+        super().__init__()
+        self.weight = mx.ones((dims,))
+        self.eps = eps
+
+    def __call__(self, x):
+        return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
 
 class MLP(nn.Module):
     def __init__(self, dim, hidden_dim):
@@ -118,7 +126,7 @@ class MLP(nn.Module):
 
     def __call__(self, x) -> mx.array:
         # This should not be GELU approx, jax.nn.gelu
-        return self.down_proj(nn.gelu_fast_approx(self.gate_proj(x)) * self.up_proj(x))
+        return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
 
 
 class TransformerBlock(nn.Module):
@@ -128,14 +136,14 @@ class TransformerBlock(nn.Module):
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(args, layer_idx)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(
+        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )
-        self.pre_feedforward_layernorm = nn.RMSNorm(
+        self.pre_feedforward_layernorm = RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )
-        self.post_feedforward_layernorm = nn.RMSNorm(
+        self.post_feedforward_layernorm = RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )
 
@@ -164,7 +172,7 @@ class Gemma3Model(nn.Module):
             TransformerBlock(args=args, layer_idx=layer_idx)
             for layer_idx in range(args.num_hidden_layers)
         ]
-        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 
     def __call__(
         self,
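
The behavioral difference driving the revert is the weight offset: the local RMSNorm scales by `1.0 + self.weight`, consistent with Gemma storing the norm weight as a zero-centered offset, whereas `nn.RMSNorm` applies its weight directly. Below is a minimal sketch of that difference, assuming freshly initialized weights; the standalone comparison script is illustrative only and not part of the patched model code.

```python
import mlx.core as mx
import mlx.nn as nn


class RMSNorm(nn.Module):
    # Same definition as in the patch: the stored weight is treated as a
    # zero-centered offset, so the effective scale is (1.0 + weight).
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)


x = mx.random.normal((2, 8))

gemma_norm = RMSNorm(8)     # fresh weight = ones -> effective scale 2.0
plain_norm = nn.RMSNorm(8)  # fresh weight = ones -> effective scale 1.0

# With these fresh weights the two outputs differ by exactly a factor of 2,
# so loading offset-style checkpoint weights into nn.RMSNorm would mis-scale
# activations throughout the model.
print(mx.allclose(gemma_norm(x), 2.0 * plain_norm(x)))  # True
```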