fix moe conversion (#802)

This commit is contained in:
Awni Hannun 2024-05-31 12:36:05 -07:00 committed by GitHub
parent f49c5f2829
commit 09aaeac72c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 15 deletions

View File

@ -202,11 +202,13 @@ class Model(nn.Module):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
to_join = [
weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}")
for e in range(self.args.num_local_experts)
]
if to_join:
if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
)
for e in range(self.args.num_local_experts)
]
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
mx.stack(to_join)
)

View File

@ -182,11 +182,11 @@ class Model(nn.Module):
prefix = f"transformer.h.{l}"
for n in ["fc1", "fc2"]:
for k in ["weight", "scales", "biases", "bias"]:
to_join = [
weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
for e in range(self.args.num_local_experts)
]
if to_join:
if f"{prefix}.moe.mlp.0.{n}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
for e in range(self.args.num_local_experts)
]
weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights

View File

@ -225,11 +225,11 @@ class Model(nn.Module):
prefix = f"model.layers.{l}"
for n in ["up_proj", "down_proj", "gate_proj"]:
for k in ["weight", "scales", "biases"]:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
for e in range(self.args.num_experts)
]
if to_join:
if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
for e in range(self.args.num_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights