revert rmsnorm

This commit is contained in:
Prince Canuma 2025-03-12 09:41:42 +01:00
parent 0e57d38f47
commit 645b666890

View File

@ -58,8 +58,8 @@ class Attention(nn.Module):
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.q_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
self.k_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern == 0
self.rope = nn.RoPE(
@ -108,6 +108,14 @@ class Attention(nn.Module):
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def __call__(self, x):
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
@ -118,7 +126,7 @@ class MLP(nn.Module):
def __call__(self, x) -> mx.array:
# This should not be GELU approx, jax.nn.gelu
return self.down_proj(nn.gelu_fast_approx(self.gate_proj(x)) * self.up_proj(x))
return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
@ -128,14 +136,14 @@ class TransformerBlock(nn.Module):
self.hidden_size = args.hidden_size
self.self_attn = Attention(args, layer_idx)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.pre_feedforward_layernorm = nn.RMSNorm(
self.pre_feedforward_layernorm = RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_feedforward_layernorm = nn.RMSNorm(
self.post_feedforward_layernorm = RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
@ -164,7 +172,7 @@ class Gemma3Model(nn.Module):
TransformerBlock(args=args, layer_idx=layer_idx)
for layer_idx in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,