mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 16:16:27 +08:00
a little faster and adding norm_topk_prob
This commit is contained in:
parent
4ca2cd5759
commit
e4c56625f0
@ -108,18 +108,19 @@ class OlmoeSparseMoeBlock(nn.Module):
|
||||
bias=args.mlp_bias
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
):
|
||||
gates = self.gate(x)
|
||||
gates = mx.softmax(gates, axis=-1, precise=True)
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
x_flat = x.reshape(-1, D)
|
||||
router_logits = self.gate(x_flat)
|
||||
routing_weights = mx.softmax(router_logits, axis=1, precise=True)
|
||||
k = self.top_k
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
y = self.switch_mlp(x, inds)
|
||||
indices = mx.stop_gradient(mx.argpartition(-routing_weights, kth=k-1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(routing_weights, indices, axis=-1)
|
||||
if self.norm_topk_prob:
|
||||
scores = scores / scores.sum(axis=-1, keepdims=True)
|
||||
y = self.switch_mlp(x_flat, indices)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
return y
|
||||
return y.reshape(B, L, D)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
|
Loading…
Reference in New Issue
Block a user