From e4c56625f0e73a4760ce3f2956a734c836d79a58 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Wed, 5 Mar 2025 00:16:20 +0100
Subject: [PATCH] a little faster and adding norm_topk_prob

---
 llms/mlx_lm/models/olmoe.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/llms/mlx_lm/models/olmoe.py b/llms/mlx_lm/models/olmoe.py
index c67288c9..b16bec8d 100644
--- a/llms/mlx_lm/models/olmoe.py
+++ b/llms/mlx_lm/models/olmoe.py
@@ -108,18 +108,19 @@ class OlmoeSparseMoeBlock(nn.Module):
             bias=args.mlp_bias
         )
 
-    def __call__(
-        self,
-        x: mx.array,
-    ):
-        gates = self.gate(x)
-        gates = mx.softmax(gates, axis=-1, precise=True)
+    def __call__(self, x: mx.array) -> mx.array:
+        B, L, D = x.shape
+        x_flat = x.reshape(-1, D)
+        router_logits = self.gate(x_flat)
+        routing_weights = mx.softmax(router_logits, axis=1, precise=True)
         k = self.top_k
-        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
-        scores = mx.take_along_axis(gates, inds, axis=-1)
-        y = self.switch_mlp(x, inds)
+        indices = mx.stop_gradient(mx.argpartition(-routing_weights, kth=k-1, axis=-1)[..., :k])
+        scores = mx.take_along_axis(routing_weights, indices, axis=-1)
+        if self.norm_topk_prob:
+            scores = scores / scores.sum(axis=-1, keepdims=True)
+        y = self.switch_mlp(x_flat, indices)
         y = (y * scores[..., None]).sum(axis=-2)
-        return y
+        return y.reshape(B, L, D)
 
 
 class TransformerBlock(nn.Module):
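
For reference, here is a minimal standalone sketch (not part of the patch) of the routing-weight computation the change introduces: softmax over the router logits, top-k expert selection via argpartition, and the optional norm_topk_prob renormalization so the selected experts' weights sum to 1 per token. The helper name top_k_routing_weights and the toy shapes are illustrative assumptions, not model code.

# Standalone sketch of the top-k routing used above; shapes are made up.
import mlx.core as mx


def top_k_routing_weights(router_logits, k, norm_topk_prob=False):
    # Per-token routing probabilities over the expert dimension.
    weights = mx.softmax(router_logits, axis=-1, precise=True)
    # Indices of the k largest probabilities per token (order within the
    # top-k is unspecified, which is fine for a weighted sum of expert outputs).
    inds = mx.stop_gradient(mx.argpartition(-weights, kth=k - 1, axis=-1)[..., :k])
    scores = mx.take_along_axis(weights, inds, axis=-1)
    if norm_topk_prob:
        # Renormalize so the selected experts' weights sum to 1 per token.
        scores = scores / scores.sum(axis=-1, keepdims=True)
    return inds, scores


# Toy example: 4 tokens routed over 8 experts, top-2.
logits = mx.random.normal((4, 8))
inds, scores = top_k_routing_weights(logits, k=2, norm_topk_prob=True)
print(inds.shape, scores.shape)  # (4, 2) (4, 2)
print(scores.sum(axis=-1))       # all ones when norm_topk_prob=True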