diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index d2740dc1..84f498e9 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -205,7 +205,7 @@ class Model(nn.Module): def sanitize(self, weights): for k, v in weights.items(): - if "conv1d.weight" in k and v.ndim == 3: + if "conv1d.weight" in k and v.shape[-1] != 1: weights[k] = v.moveaxis(2, 1) return weights diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 06a307a6..5595d311 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -440,7 +440,7 @@ class Model(nn.Module): def sanitize(self, weights): for k, v in weights.items(): - if "conv_1d.weight" in k and v.ndim == 3: + if "conv_1d.weight" in k and v.shape[-1] != 1: weights[k] = v.moveaxis(2, 1) if "lm_head.weight" not in weights: self.pop("lm_head")