a little faster and adding norm_topk_prob

Goekdeniz-Guelmez 2025-03-05 00:16:20 +01:00
parent 4ca2cd5759
commit e4c56625f0


@@ -108,18 +108,19 @@ class OlmoeSparseMoeBlock(nn.Module):
             bias=args.mlp_bias
         )
 
-    def __call__(
-        self,
-        x: mx.array,
-    ):
-        gates = self.gate(x)
-        gates = mx.softmax(gates, axis=-1, precise=True)
-
-        k = self.top_k
-        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
-        scores = mx.take_along_axis(gates, inds, axis=-1)
-        y = self.switch_mlp(x, inds)
-        y = (y * scores[..., None]).sum(axis=-2)
-        return y
+    def __call__(self, x: mx.array) -> mx.array:
+        B, L, D = x.shape
+        x_flat = x.reshape(-1, D)
+        router_logits = self.gate(x_flat)
+        routing_weights = mx.softmax(router_logits, axis=1, precise=True)
+
+        k = self.top_k
+        indices = mx.stop_gradient(mx.argpartition(-routing_weights, kth=k-1, axis=-1)[..., :k])
+        scores = mx.take_along_axis(routing_weights, indices, axis=-1)
+        if self.norm_topk_prob:
+            scores = scores / scores.sum(axis=-1, keepdims=True)
+        y = self.switch_mlp(x_flat, indices)
+        y = (y * scores[..., None]).sum(axis=-2)
+        return y.reshape(B, L, D)
 
 class TransformerBlock(nn.Module):
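
For reference, here is a minimal standalone sketch (not the repository's code) of the routing math this change touches: softmax over the expert logits, top-k selection via argpartition, and the optional renormalization controlled by norm_topk_prob. The helper name route_topk and the example shapes are assumptions for illustration; only the mx.* calls and the norm_topk_prob flag mirror the diff above.

    import mlx.core as mx

    def route_topk(router_logits: mx.array, k: int, norm_topk_prob: bool):
        # Hypothetical helper, for illustration only.
        # Per-token routing weights over all experts.
        routing_weights = mx.softmax(router_logits, axis=-1, precise=True)
        # Indices of the k largest weights per token (argpartition avoids a full sort).
        indices = mx.stop_gradient(
            mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k]
        )
        scores = mx.take_along_axis(routing_weights, indices, axis=-1)
        if norm_topk_prob:
            # Renormalize so the k selected weights sum to 1 for each token.
            scores = scores / scores.sum(axis=-1, keepdims=True)
        return indices, scores

    logits = mx.random.normal((4, 64))   # 4 flattened tokens, 64 experts (illustrative)
    inds, scores = route_topk(logits, k=8, norm_topk_prob=True)
    print(scores.sum(axis=-1))           # ~1.0 per token when normalization is on

When norm_topk_prob is off, the selected scores keep their raw softmax mass, so the expert outputs are weighted by less than 1 in total; turning it on rescales the k weights to sum to 1 per token, matching the flag of the same name in the Hugging Face OLMoE configuration.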