From c13de475f6d2646efb1b18375fbe20cf0c4dd784 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 21 Jan 2025 22:52:45 +0100
Subject: [PATCH] removing custom RMSNorm class

---
 llms/mlx_lm/models/mamba2.py | 26 --------------------------
 1 file changed, 26 deletions(-)

diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index e5a9133f..432ab994 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -44,32 +44,6 @@ class ModelArgs(BaseModelArgs):
         self.time_step_rank = math.ceil(self.hidden_size / 16)
 
 
-class MambaRMSNormGated(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6, norm_before_gate=False):
-        super().__init__()
-        self.weight = mx.ones((hidden_size,))
-        self.variance_epsilon = eps
-        self.norm_before_gate = norm_before_gate
-
-    def rms_norm(self, x):
-        variance = mx.mean(x ** 2, axis=-1, keepdims=True)
-        x = x * mx.rsqrt(variance + self.variance_epsilon)
-        return self.weight * x
-
-    def __call__(self, x, z=None):
-        if z is None:
-            return self.rms_norm(x)
-
-        if self.norm_before_gate:
-            x = self.rms_norm(x)
-            x = x * nn.silu(z)
-        else:
-            x = x * nn.silu(z)
-            x = self.rms_norm(x)
-
-        return x
-
-
 def silu(x):
     return x * mx.sigmoid(x)