fix moe conversion (#802)

This commit is contained in:
Awni Hannun
2024-05-31 12:36:05 -07:00
committed by GitHub
parent f49c5f2829
commit 09aaeac72c
3 changed files with 17 additions and 15 deletions

View File

@@ -225,11 +225,11 @@ class Model(nn.Module):
prefix = f"model.layers.{l}"
for n in ["up_proj", "down_proj", "gate_proj"]:
for k in ["weight", "scales", "biases"]:
-to_join = [
-weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
-for e in range(self.args.num_experts)
-]
-if to_join:
+if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
+to_join = [
+weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
+for e in range(self.args.num_experts)
+]
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights