diff --git a/llms/mlx_lm/models/olmoe.py b/llms/mlx_lm/models/olmoe.py
index f0ce097a..73d9b596 100644
--- a/llms/mlx_lm/models/olmoe.py
+++ b/llms/mlx_lm/models/olmoe.py
@@ -124,9 +124,8 @@ class OlmoeSparseMoeBlock(nn.Module):
         final_hidden_states = mx.zeros_like(x)
         for expert_idx in range(self.num_experts):
             expert_weights = routing_weights[:, expert_idx:expert_idx+1]
-            if mx.max(expert_weights) > 1e-5:
-                expert_output = self.experts[expert_idx](x)
-                final_hidden_states += expert_output * expert_weights
+            expert_output = self.experts[expert_idx](x)
+            final_hidden_states += expert_output * expert_weights
 
         return final_hidden_states.reshape(B, L, D)
 