diff --git a/llms/mlx_lm/models/olmoe.py b/llms/mlx_lm/models/olmoe.py
index a2b2b44e..c67288c9 100644
--- a/llms/mlx_lm/models/olmoe.py
+++ b/llms/mlx_lm/models/olmoe.py
@@ -108,32 +108,18 @@ class OlmoeSparseMoeBlock(nn.Module):
             bias=args.mlp_bias
         )
 
-    def __call__(self, x: mx.array) -> mx.array:
-        B, L, D = x.shape
-        x_flat = x.reshape(-1, D)
-
-        # Compute routing probabilities
-        router_logits = self.gate(x_flat)
-        routing_weights = mx.softmax(router_logits, axis=1, precise=True)
-
-        # Get top-k experts
-        top_k = self.top_k
-        indices = mx.stop_gradient(mx.argpartition(-routing_weights, kth=top_k-1, axis=-1)[..., :top_k])
-        scores = mx.take_along_axis(routing_weights, indices, axis=-1)
-
-        # Normalize probabilities (optional)
-        if self.norm_topk_prob:
-            scores = scores / scores.sum(axis=-1, keepdims=True)
-
-        # Reshape for switch_mlp
-        x_reshaped = x_flat.reshape(B*L, D)
-        indices_reshaped = indices.reshape(B*L, top_k)
-
-        # Apply experts and combine with routing weights
-        expert_outputs = self.switch_mlp(x_reshaped, indices_reshaped)
-        outputs = (expert_outputs * scores.reshape(B*L, top_k, 1)).sum(axis=1)
-
-        return outputs.reshape(B, L, D)
+    def __call__(
+        self,
+        x: mx.array,
+    ):
+        gates = self.gate(x)
+        gates = mx.softmax(gates, axis=-1, precise=True)
+        k = self.top_k
+        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
+        scores = mx.take_along_axis(gates, inds, axis=-1)
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2)
+        return y
 
 
 class TransformerBlock(nn.Module):
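
Two notes on the rewrite for reviewers. First, the manual `(B*L, D)` flattening and reshaping is gone: as the new `__call__` relies on, `SwitchGLU` accepts inputs with leading batch dimensions together with per-token expert indices, so the bookkeeping was unnecessary. Second, the `norm_topk_prob` renormalization branch was dropped, which should be behavior-preserving for the released OLMoE checkpoints since their configs leave that flag disabled.

Below is a minimal, self-contained sketch of the top-k selection step the new code uses, with the shapes spelled out. The helper name `topk_route` is mine for illustration only and is not part of this diff.

import mlx.core as mx


def topk_route(gates: mx.array, k: int):
    """Select the k most probable experts per token.

    gates: routing probabilities, shape (..., num_experts).
    Returns (indices, scores), each of shape (..., k).
    """
    # argpartition on the negated probabilities puts the k largest
    # entries in the first k slots; ordering within the top-k is
    # arbitrary, which is fine because the expert outputs are summed.
    # The indices are integral and non-differentiable, hence stop_gradient.
    inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
    # Gather the routing probability of each selected expert.
    scores = mx.take_along_axis(gates, inds, axis=-1)
    return inds, scores


# Smoke test: 2 tokens, 8 experts, each token routed to its top 2 experts.
logits = mx.random.normal((2, 8))
probs = mx.softmax(logits, axis=-1, precise=True)
inds, scores = topk_route(probs, k=2)
print(inds.shape, scores.shape)  # (2, 2) (2, 2)

In the model itself, `y = self.switch_mlp(x, inds)` then evaluates only the selected experts, returning one output per selected expert along a new axis of size k, and `(y * scores[..., None]).sum(axis=-2)` combines them as a routing-weighted sum.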