mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
fix mamba models conversion (#1065)
This commit is contained in:
parent
d1d480867b
commit
9000e280ae
@ -205,7 +205,7 @@ class Model(nn.Module):
|
||||
|
||||
def sanitize(self, weights):
|
||||
for k, v in weights.items():
|
||||
if "conv1d.weight" in k and v.ndim == 3:
|
||||
if "conv1d.weight" in k and v.shape[-1] != 1:
|
||||
weights[k] = v.moveaxis(2, 1)
|
||||
return weights
|
||||
|
||||
|
@ -440,7 +440,7 @@ class Model(nn.Module):
|
||||
|
||||
def sanitize(self, weights):
|
||||
for k, v in weights.items():
|
||||
if "conv_1d.weight" in k and v.ndim == 3:
|
||||
if "conv_1d.weight" in k and v.shape[-1] != 1:
|
||||
weights[k] = v.moveaxis(2, 1)
|
||||
if "lm_head.weight" not in weights:
|
||||
self.pop("lm_head")
|
||||
|
Loading…
Reference in New Issue
Block a user