mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
fix moe conversion (#802)
This commit is contained in:
@@ -182,11 +182,11 @@ class Model(nn.Module):
|
||||
prefix = f"transformer.h.{l}"
|
||||
for n in ["fc1", "fc2"]:
|
||||
for k in ["weight", "scales", "biases", "bias"]:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_local_experts)
|
||||
]
|
||||
if to_join:
|
||||
if f"{prefix}.moe.mlp.0.{n}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_local_experts)
|
||||
]
|
||||
weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
|
Reference in New Issue
Block a user