fix mamba models conversion (#1065)

This commit is contained in:
Awni Hannun
2024-10-22 15:44:08 -07:00
committed by GitHub
parent d1d480867b
commit 9000e280ae
2 changed files with 2 additions and 2 deletions

View File

@@ -205,7 +205,7 @@ class Model(nn.Module):
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
return weights

View File

@@ -440,7 +440,7 @@ class Model(nn.Module):
def sanitize(self, weights):
for k, v in weights.items():
if "conv_1d.weight" in k and v.ndim == 3:
if "conv_1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
if "lm_head.weight" not in weights:
self.pop("lm_head")