Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-07-21 11:01:13 +08:00)

commit 645b666890: revert rmsnorm
parent 0e57d38f47
@@ -58,8 +58,8 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

-        self.q_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
-        self.k_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
+        self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
+        self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
         self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern == 0

         self.rope = nn.RoPE(
@@ -108,6 +108,14 @@ class Attention(nn.Module):
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

+class RMSNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5):
+        super().__init__()
+        self.weight = mx.ones((dims,))
+        self.eps = eps
+
+    def __call__(self, x):
+        return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)

 class MLP(nn.Module):
     def __init__(self, dim, hidden_dim):
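Note: the reverted RMSNorm stores a zero-centered weight and applies it as (1.0 + weight) via mx.fast.rms_norm, which follows the Gemma convention, whereas nn.RMSNorm scales by its weight directly. A minimal sanity check of that relationship (illustrative only, not part of this commit; shapes and values are arbitrary):

import mlx.core as mx
import mlx.nn as nn

dims, eps = 16, 1e-6
x = mx.random.normal((2, dims))
w = 0.1 * mx.random.normal((dims,))            # zero-centered weight, as Gemma checkpoints store it

gemma_out = mx.fast.rms_norm(x, 1.0 + w, eps)  # what the reverted RMSNorm computes

ref = nn.RMSNorm(dims, eps=eps)
ref.weight = 1.0 + w                           # nn.RMSNorm needs the shifted weight to match
print(mx.allclose(gemma_out, ref(x)))          # expected: True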
@@ -118,7 +126,7 @@ class MLP(nn.Module):

     def __call__(self, x) -> mx.array:
         # This should not be GELU approx, jax.nn.gelu
-        return self.down_proj(nn.gelu_fast_approx(self.gate_proj(x)) * self.up_proj(x))
+        return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))


 class TransformerBlock(nn.Module):
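Note: the MLP activation moves from nn.gelu_fast_approx back to nn.gelu_approx; the in-line comment points at jax.nn.gelu as the reference, and gelu_approx is, as far as the MLX docs describe, the closer of MLX's two GELU approximations. A quick numerical comparison against the exact nn.gelu (illustrative only, not part of this commit):

import mlx.core as mx
import mlx.nn as nn

x = mx.linspace(-6, 6, 1001)
exact = nn.gelu(x)

# Maximum absolute deviation of each approximation from the exact GELU on [-6, 6].
err_approx = mx.abs(nn.gelu_approx(x) - exact).max().item()
err_fast = mx.abs(nn.gelu_fast_approx(x) - exact).max().item()
print(err_approx, err_fast)  # gelu_approx is expected to show the smaller error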
@@ -128,14 +136,14 @@ class TransformerBlock(nn.Module):
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(args, layer_idx)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(
+        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )
-        self.pre_feedforward_layernorm = nn.RMSNorm(
+        self.pre_feedforward_layernorm = RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )
-        self.post_feedforward_layernorm = nn.RMSNorm(
+        self.post_feedforward_layernorm = RMSNorm(
             args.hidden_size, eps=args.rms_norm_eps
         )

@@ -164,7 +172,7 @@ class Gemma3Model(nn.Module):
             TransformerBlock(args=args, layer_idx=layer_idx)
             for layer_idx in range(args.num_hidden_layers)
         ]
-        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

     def __call__(
         self,