From 2462a34194279bbb2d8bffa4d2da4db12343c38c Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Wed, 22 Jan 2025 22:30:15 +0100
Subject: [PATCH] removing sanitize

---
 llms/mlx_lm/models/mamba2.py | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 888fb4fa..7302bde5 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -24,8 +24,6 @@ class ModelArgs(BaseModelArgs):
     use_conv_bias: bool
     initializer_range: float
     residual_in_fp32: bool
-    rescale_prenorm_residual: bool
-    rms_norm: bool
     chunk_size: int
     tie_word_embeddings: bool
     time_step_limit: Tuple[float, float]
@@ -162,7 +160,8 @@ class Mamba2Block(nn.Module):
 
         if cache is None:
             cache = [None, None]
-        conv_state, _ = cache
+        else:
+            conv_state, ssm_state = cache
 
         zxBCdt = self.in_proj(u)
 
@@ -173,7 +172,7 @@ class Mamba2Block(nn.Module):
         )
 
         xBC, conv_state = self.conv1d(xBC, conv_state)
-        xBC =xBC * mx.sigmoid(xBC)
+        xBC = xBC * mx.sigmoid(xBC)
         xBC = xBC[:, :seq_len, :]
 
         x, B, C = mx.split(
@@ -187,7 +186,7 @@ class Mamba2Block(nn.Module):
         C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
         A = -mx.exp(self.A_log)
 
-        y, next_state = ssd_forward_attn(
+        y, next_ssm_state = ssd_forward_attn(
             x=x,
             dt=dt,
             A=A,
@@ -209,8 +208,7 @@ class Mamba2Block(nn.Module):
         y = self.out_proj(y)
 
         cache[0] = conv_state
-        cache[1] = next_state
-
+        cache[1] = next_ssm_state
         return y
 
 
@@ -267,12 +265,6 @@ class Model(nn.Module):
         logits = self.lm_head(hidden)
 
         return logits
-
-    def sanitize(self, weights):
-        for k, v in weights.items():
-            if "conv1d.weight" in k and v.shape[-1] != 1:
-                weights[k] = v.moveaxis(2, 1)
-        return weights
 
     def make_cache(self):
         return [MambaCache() for _ in range(len(self.layers))]
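
For context on the state handling this patch touches: the block keeps a two-slot cache, where slot 0 holds the conv1d state and slot 1 holds the SSM state, and both are written back into the same cache object at the end of the forward pass. Below is a minimal, illustrative sketch of that convention only; the helper name and the placeholder state updates are made up here and are not part of the patch, and the sketch unpacks the freshly created [None, None] cache unconditionally so that both names are always bound.

import mlx.core as mx

def step_with_cache(u, cache=None):
    # Two-slot cache convention used in the patch: cache[0] = conv state, cache[1] = SSM state.
    if cache is None:
        cache = [None, None]
    conv_state, ssm_state = cache  # both slots may be None on the first call

    # Placeholder computation standing in for the block's conv1d + SSD scan.
    y = u * mx.sigmoid(u)
    next_conv_state = u                # hypothetical updated conv window
    next_ssm_state = mx.zeros_like(u)  # hypothetical updated SSM state

    # Write the updated states back into the same cache object, as the block does.
    cache[0] = next_conv_state
    cache[1] = next_ssm_state
    return y, cache

y, cache = step_with_cache(mx.arange(4.0))  # first call creates the cache
y, cache = step_with_cache(y, cache)        # later calls reuse and mutate it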