diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index ee401e8e..7d1b10ac 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -202,11 +202,13 @@ class Model(nn.Module):
             prefix = f"model.layers.{l}"
             for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
                 for k in ["weight", "scales", "biases"]:
-                    to_join = [
-                        weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}")
-                        for e in range(self.args.num_local_experts)
-                    ]
-                    if to_join:
+                    if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
+                        to_join = [
+                            weights.pop(
+                                f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
+                            )
+                            for e in range(self.args.num_local_experts)
+                        ]
                         weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
                             mx.stack(to_join)
                         )
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index ded56c68..40a3bc4b 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -182,11 +182,11 @@ class Model(nn.Module):
             prefix = f"transformer.h.{l}"
             for n in ["fc1", "fc2"]:
                 for k in ["weight", "scales", "biases", "bias"]:
-                    to_join = [
-                        weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
-                        for e in range(self.args.num_local_experts)
-                    ]
-                    if to_join:
+                    if f"{prefix}.moe.mlp.0.{n}.{k}" in weights:
+                        to_join = [
+                            weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
+                            for e in range(self.args.num_local_experts)
+                        ]
                         weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
 
         return weights
diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py
index 1bd065aa..bba02da0 100644
--- a/llms/mlx_lm/models/qwen2_moe.py
+++ b/llms/mlx_lm/models/qwen2_moe.py
@@ -225,11 +225,11 @@ class Model(nn.Module):
             prefix = f"model.layers.{l}"
             for n in ["up_proj", "down_proj", "gate_proj"]:
                 for k in ["weight", "scales", "biases"]:
-                    to_join = [
-                        weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
-                        for e in range(self.args.num_experts)
-                    ]
-                    if to_join:
+                    if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
+                        to_join = [
+                            weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
+                            for e in range(self.args.num_experts)
+                        ]
                         weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
 
         return weights