From 9000e280aeb56c2bcce128001ab157030095687a Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 22 Oct 2024 15:44:08 -0700
Subject: [PATCH] fix mamba models conversion (#1065)

---
 llms/mlx_lm/models/mamba.py           | 2 +-
 llms/mlx_lm/models/recurrent_gemma.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index d2740dc1..84f498e9 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -205,7 +205,7 @@ class Model(nn.Module):
 
     def sanitize(self, weights):
         for k, v in weights.items():
-            if "conv1d.weight" in k and v.ndim == 3:
+            if "conv1d.weight" in k and v.shape[-1] != 1:
                 weights[k] = v.moveaxis(2, 1)
         return weights
 
diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py
index 06a307a6..5595d311 100644
--- a/llms/mlx_lm/models/recurrent_gemma.py
+++ b/llms/mlx_lm/models/recurrent_gemma.py
@@ -440,7 +440,7 @@ class Model(nn.Module):
 
     def sanitize(self, weights):
         for k, v in weights.items():
-            if "conv_1d.weight" in k and v.ndim == 3:
+            if "conv_1d.weight" in k and v.shape[-1] != 1:
                 weights[k] = v.moveaxis(2, 1)
         if "lm_head.weight" not in weights:
             self.pop("lm_head")