Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-29 06:54:39 +08:00).
Commit 4ca2cd5759 ("clean up"); parent commit 140285080d.
@ -108,32 +108,18 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|||||||
bias=args.mlp_bias
|
bias=args.mlp_bias
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
    """Route each position to its top-k experts and mix their outputs.

    Args:
        x: Input activations. The gate scores experts along the last
           axis.  # NOTE(review): presumably (..., hidden_dim) — confirm against callers

    Returns:
        mx.array with the same leading shape as ``x``: the
        routing-probability-weighted sum of the selected experts' outputs.
    """
    # Router logits -> per-token probability distribution over experts.
    gates = self.gate(x)
    gates = mx.softmax(gates, axis=-1, precise=True)

    k = self.top_k
    # Indices of the k highest-probability experts. argpartition on the
    # negated gates avoids a full sort; stop_gradient keeps the
    # non-differentiable selection out of the backward pass.
    inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
    # Gather the chosen experts' probabilities to use as mixing weights.
    scores = mx.take_along_axis(gates, inds, axis=-1)

    # Evaluate only the selected experts, then combine their outputs
    # weighted by the routing scores (reduce over the expert axis).
    y = self.switch_mlp(x, inds)
    y = (y * scores[..., None]).sum(axis=-2)

    return y
class TransformerBlock(nn.Module):
|
class TransformerBlock(nn.Module):
|
||||||
|
Loading…
Reference in New Issue
Block a user