removing custom RMSNorm class

This commit is contained in:
Goekdeniz-Guelmez 2025-01-21 22:52:45 +01:00
parent a6a92cb91f
commit c13de475f6

View File

@@ -44,32 +44,6 @@ class ModelArgs(BaseModelArgs):
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaRMSNormGated(nn.Module):
    """RMS normalization with an optional SiLU gate.

    The layer holds a single ``weight`` vector of length ``hidden_size``
    (initialized to ones) that scales the normalized activations. When a
    gate tensor ``z`` is supplied, the input is multiplied by ``silu(z)``
    either after or before normalization, depending on ``norm_before_gate``.
    """

    def __init__(self, hidden_size, eps=1e-6, norm_before_gate=False):
        """Create the layer.

        Args:
            hidden_size: Length of the ``weight`` vector.
            eps: Constant added to the mean square before the reciprocal
                square root is taken.
            norm_before_gate: If true, normalize first and gate afterwards;
                otherwise gate first and normalize the gated result.
        """
        super().__init__()
        self.weight = mx.ones((hidden_size,))
        self.variance_epsilon = eps
        self.norm_before_gate = norm_before_gate

    def rms_norm(self, x):
        """Return ``weight * x / sqrt(mean(x**2) + eps)`` over the last axis."""
        mean_square = mx.mean(x ** 2, axis=-1, keepdims=True)
        inv_rms = mx.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * (x * inv_rms)

    def __call__(self, x, z=None):
        """Normalize ``x``, optionally gating it with ``silu(z)``.

        Args:
            x: Input array; normalized over its last axis.
            z: Optional gate array. When ``None`` no gating is applied.

        Returns:
            The normalized (and, if ``z`` is given, gated) array.
        """
        # No gate supplied: plain RMS normalization.
        if z is None:
            return self.rms_norm(x)

        gate = nn.silu(z)
        # Normalize-then-gate ordering.
        if self.norm_before_gate:
            return self.rms_norm(x) * gate
        # Gate-then-normalize ordering.
        return self.rms_norm(x * gate)
def silu(x):
    """Return the SiLU activation of ``x``, i.e. ``x * sigmoid(x)``."""
    gate = mx.sigmoid(x)
    return x * gate