faster generation

This commit is contained in:
Goekdeniz-Guelmez 2025-03-04 21:26:21 +01:00
parent fd63c68280
commit 8b6beea3be

View File

@@ -124,9 +124,8 @@ class OlmoeSparseMoeBlock(nn.Module):
final_hidden_states = mx.zeros_like(x)
for expert_idx in range(self.num_experts):
expert_weights = routing_weights[:, expert_idx:expert_idx+1]
if mx.max(expert_weights) > 1e-5:
expert_output = self.experts[expert_idx](x)
final_hidden_states += expert_output * expert_weights
expert_output = self.experts[expert_idx](x)
final_hidden_states += expert_output * expert_weights
return final_hidden_states.reshape(B, L, D)