fix mamba models conversion (#1065)

This commit is contained in:
Awni Hannun 2024-10-22 15:44:08 -07:00 committed by GitHub
parent d1d480867b
commit 9000e280ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@@ -205,7 +205,7 @@ class Model(nn.Module):
     def sanitize(self, weights):
         for k, v in weights.items():
-            if "conv1d.weight" in k and v.ndim == 3:
+            if "conv1d.weight" in k and v.shape[-1] != 1:
                 weights[k] = v.moveaxis(2, 1)
         return weights

View File

@@ -440,7 +440,7 @@ class Model(nn.Module):
     def sanitize(self, weights):
         for k, v in weights.items():
-            if "conv_1d.weight" in k and v.ndim == 3:
+            if "conv_1d.weight" in k and v.shape[-1] != 1:
                 weights[k] = v.moveaxis(2, 1)
         if "lm_head.weight" not in weights:
             self.pop("lm_head")