From 09aaeac72caf0547aeacf2f2cac86195aa999cc9 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 31 May 2024 12:36:05 -0700 Subject: [PATCH] fix moe conversion (#802) --- llms/mlx_lm/models/mixtral.py | 12 +++++++----- llms/mlx_lm/models/phixtral.py | 10 +++++----- llms/mlx_lm/models/qwen2_moe.py | 10 +++++----- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index ee401e8e..7d1b10ac 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -202,11 +202,13 @@ class Model(nn.Module): prefix = f"model.layers.{l}" for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: for k in ["weight", "scales", "biases"]: - to_join = [ - weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}") - for e in range(self.args.num_local_experts) - ] - if to_join: + if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights: + to_join = [ + weights.pop( + f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}" + ) + for e in range(self.args.num_local_experts) + ] weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = ( mx.stack(to_join) ) diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index ded56c68..40a3bc4b 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -182,11 +182,11 @@ class Model(nn.Module): prefix = f"transformer.h.{l}" for n in ["fc1", "fc2"]: for k in ["weight", "scales", "biases", "bias"]: - to_join = [ - weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}") - for e in range(self.args.num_local_experts) - ] - if to_join: + if f"{prefix}.moe.mlp.0.{n}.{k}" in weights: + to_join = [ + weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}") + for e in range(self.args.num_local_experts) + ] weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join) return weights diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index 1bd065aa..bba02da0 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -225,11 +225,11 @@ class Model(nn.Module): prefix = f"model.layers.{l}" for n in ["up_proj", "down_proj", "gate_proj"]: for k in ["weight", "scales", "biases"]: - to_join = [ - weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") - for e in range(self.args.num_experts) - ] - if to_join: + if f"{prefix}.mlp.experts.0.{n}.{k}" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") + for e in range(self.args.num_experts) + ] weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) return weights