mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
fix moe conversion (#802)
This commit is contained in:
@@ -225,11 +225,11 @@ class Model(nn.Module):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n in ["up_proj", "down_proj", "gate_proj"]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_experts)
|
||||
]
|
||||
if to_join:
|
||||
if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
|
Reference in New Issue
Block a user