fix moe conversion (#802)

This commit is contained in:
Awni Hannun 2024-05-31 12:36:05 -07:00 committed by GitHub
parent f49c5f2829
commit 09aaeac72c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 15 deletions

View File

@ -202,11 +202,13 @@ class Model(nn.Module):
prefix = f"model.layers.{l}" prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]: for k in ["weight", "scales", "biases"]:
to_join = [ if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}") to_join = [
for e in range(self.args.num_local_experts) weights.pop(
] f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
if to_join: )
for e in range(self.args.num_local_experts)
]
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = ( weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
mx.stack(to_join) mx.stack(to_join)
) )

View File

@ -182,11 +182,11 @@ class Model(nn.Module):
prefix = f"transformer.h.{l}" prefix = f"transformer.h.{l}"
for n in ["fc1", "fc2"]: for n in ["fc1", "fc2"]:
for k in ["weight", "scales", "biases", "bias"]: for k in ["weight", "scales", "biases", "bias"]:
to_join = [ if f"{prefix}.moe.mlp.0.{n}.{k}" in weights:
weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}") to_join = [
for e in range(self.args.num_local_experts) weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
] for e in range(self.args.num_local_experts)
if to_join: ]
weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join) weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights return weights

View File

@ -225,11 +225,11 @@ class Model(nn.Module):
prefix = f"model.layers.{l}" prefix = f"model.layers.{l}"
for n in ["up_proj", "down_proj", "gate_proj"]: for n in ["up_proj", "down_proj", "gate_proj"]:
for k in ["weight", "scales", "biases"]: for k in ["weight", "scales", "biases"]:
to_join = [ if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") to_join = [
for e in range(self.args.num_experts) weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
] for e in range(self.args.num_experts)
if to_join: ]
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights return weights