mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-28 20:25:22 +08:00
removing custom RMSNorm class
This commit is contained in:
parent
a6a92cb91f
commit
c13de475f6
@ -44,32 +44,6 @@ class ModelArgs(BaseModelArgs):
|
|||||||
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
||||||
|
|
||||||
|
|
||||||
class MambaRMSNormGated(nn.Module):
    """RMS normalization with an optional SiLU gate.

    Without a gate the module is a plain weighted RMS norm. With a gate
    ``z``, the input is multiplied by ``nn.silu(z)`` either after the
    normalization (``norm_before_gate=True``) or before it (the default).
    """

    def __init__(self, hidden_size, eps=1e-6, norm_before_gate=False):
        """Create the scale vector and store the configuration.

        Args:
            hidden_size: Length of the learnable scale vector ``weight``,
                which is initialized to ones.
            eps: Constant added to the mean square before ``rsqrt``.
            norm_before_gate: When true, normalize first and gate second;
                otherwise gate first and normalize second.
        """
        super().__init__()
        self.weight = mx.ones((hidden_size,))
        self.variance_epsilon = eps
        self.norm_before_gate = norm_before_gate

    def rms_norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last axis, then by ``weight``."""
        mean_square = mx.mean(x ** 2, axis=-1, keepdims=True)
        inv_rms = mx.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * (x * inv_rms)

    def __call__(self, x, z=None):
        """Normalize ``x``, applying the SiLU gate ``z`` when one is given.

        Args:
            x: Input array to normalize.
            z: Optional gate input; ``None`` disables gating entirely.

        Returns:
            The normalized (and, if ``z`` is given, gated) array.
        """
        # No gate: behave as an ordinary RMS norm.
        if z is None:
            return self.rms_norm(x)

        gate = nn.silu(z)
        if self.norm_before_gate:
            # Normalize first, then apply the gate to the normalized values.
            return self.rms_norm(x) * gate
        # Gate first, so the normalization sees the gated activations.
        return self.rms_norm(x * gate)
|
|
||||||
|
|
||||||
|
|
||||||
def silu(x):
    """Return the SiLU activation of ``x``: ``x`` times ``mx.sigmoid(x)``."""
    gate = mx.sigmoid(x)
    return x * gate
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user