mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 18:51:18 +08:00
fix mamba models conversion (#1065)
This commit is contained in:
parent
d1d480867b
commit
9000e280ae
@ -205,7 +205,7 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
for k, v in weights.items():
|
for k, v in weights.items():
|
||||||
if "conv1d.weight" in k and v.ndim == 3:
|
if "conv1d.weight" in k and v.shape[-1] != 1:
|
||||||
weights[k] = v.moveaxis(2, 1)
|
weights[k] = v.moveaxis(2, 1)
|
||||||
return weights
|
return weights
|
||||||
|
|
||||||
|
@ -440,7 +440,7 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
for k, v in weights.items():
|
for k, v in weights.items():
|
||||||
if "conv_1d.weight" in k and v.ndim == 3:
|
if "conv_1d.weight" in k and v.shape[-1] != 1:
|
||||||
weights[k] = v.moveaxis(2, 1)
|
weights[k] = v.moveaxis(2, 1)
|
||||||
if "lm_head.weight" not in weights:
|
if "lm_head.weight" not in weights:
|
||||||
self.pop("lm_head")
|
self.pop("lm_head")
|
||||||
|
Loading…
Reference in New Issue
Block a user