mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-07-20 02:02:41 +08:00)
revert rmsnorm
parent 0e57d38f47
commit 645b666890
@@ -58,8 +58,8 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
-        self.q_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
-        self.k_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
+        self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
+        self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
         self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern == 0
 
         self.rope = nn.RoPE(
@@ -108,6 +108,14 @@ class Attention(nn.Module):
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
 
+class RMSNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5):
+        super().__init__()
+        self.weight = mx.ones((dims,))
+        self.eps = eps
+
+    def __call__(self, x):
+        return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
 
 class MLP(nn.Module):
     def __init__(self, dim, hidden_dim):
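Note (not part of the diff): the revert swaps nn.RMSNorm back to this local RMSNorm because Gemma checkpoints store zero-centered norm scales, so the effective scale is (1 + weight); nn.RMSNorm applies its stored weight directly. A minimal sketch of the difference, using only APIs that already appear in this file (mx.fast.rms_norm, nn.RMSNorm) plus illustrative random inputs:

import mlx.core as mx
import mlx.nn as nn

dims, eps = 8, 1e-6
x = mx.random.normal((2, dims))

# Gemma-style: the stored weight is zero-centered; effective scale is 1 + w.
w = mx.zeros((dims,))  # a freshly initialized (identity) Gemma scale
custom = mx.fast.rms_norm(x, 1.0 + w, eps)

# nn.RMSNorm scales by its weight directly (initialized to ones).
builtin = nn.RMSNorm(dims, eps=eps)(x)

print(mx.allclose(custom, builtin))  # True only because 1 + zeros == ones

With real checkpoint weights loaded, the two are no longer interchangeable: nn.RMSNorm would scale by w where the model expects 1 + w.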
@@ -118,7 +126,7 @@ class MLP(nn.Module):
 
     def __call__(self, x) -> mx.array:
         # This should not be GELU approx, jax.nn.gelu
-        return self.down_proj(nn.gelu_fast_approx(self.gate_proj(x)) * self.up_proj(x))
+        return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
 
 
 class TransformerBlock(nn.Module):
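Note (not part of the diff): the MLP hunk moves from nn.gelu_fast_approx to nn.gelu_approx. Both approximate the exact erf-based GELU; gelu_fast_approx trades accuracy for speed, and the comment above the changed line points at the reference activation (jax.nn.gelu) as the target. A quick sketch, assuming current mlx activation functions, comparing both against the exact GELU:

import mlx.core as mx
import mlx.nn as nn

x = mx.linspace(-4, 4, 9)

exact = nn.gelu(x)             # exact GELU (erf-based)
approx = nn.gelu_approx(x)     # what the commit switches to
fast = nn.gelu_fast_approx(x)  # what the commit switches away from

print(mx.max(mx.abs(approx - exact)))  # small deviation
print(mx.max(mx.abs(fast - exact)))    # noticeably larger deviation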
@@ -128,14 +136,14 @@ class TransformerBlock(nn.Module):
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(args, layer_idx)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(
+        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )
-        self.pre_feedforward_layernorm = nn.RMSNorm(
+        self.pre_feedforward_layernorm = RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )
-        self.post_feedforward_layernorm = nn.RMSNorm(
+        self.post_feedforward_layernorm = RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )
 
@@ -164,7 +172,7 @@ class Gemma3Model(nn.Module):
             TransformerBlock(args=args, layer_idx=layer_idx)
             for layer_idx in range(args.num_hidden_layers)
         ]
-        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 
     def __call__(
         self,