mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 06:54:39 +08:00
faster generation
This commit is contained in:
parent
fd63c68280
commit
8b6beea3be
@@ -124,9 +124,8 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|||||||
final_hidden_states = mx.zeros_like(x)
|
final_hidden_states = mx.zeros_like(x)
|
||||||
for expert_idx in range(self.num_experts):
|
for expert_idx in range(self.num_experts):
|
||||||
expert_weights = routing_weights[:, expert_idx:expert_idx+1]
|
expert_weights = routing_weights[:, expert_idx:expert_idx+1]
|
||||||
if mx.max(expert_weights) > 1e-5:
|
expert_output = self.experts[expert_idx](x)
|
||||||
expert_output = self.experts[expert_idx](x)
|
final_hidden_states += expert_output * expert_weights
|
||||||
final_hidden_states += expert_output * expert_weights
|
|
||||||
return final_hidden_states.reshape(B, L, D)
|
return final_hidden_states.reshape(B, L, D)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user