This commit is contained in:
Goekdeniz-Guelmez 2025-03-05 09:30:09 +01:00
parent e4c56625f0
commit a1ff1bf72a

View File

@ -105,7 +105,7 @@ class OlmoeSparseMoeBlock(nn.Module):
args.hidden_size, args.hidden_size,
args.intermediate_size, args.intermediate_size,
self.num_experts, self.num_experts,
bias=args.mlp_bias bias=args.mlp_bias,
) )
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
@ -114,7 +114,9 @@ class OlmoeSparseMoeBlock(nn.Module):
router_logits = self.gate(x_flat) router_logits = self.gate(x_flat)
routing_weights = mx.softmax(router_logits, axis=1, precise=True) routing_weights = mx.softmax(router_logits, axis=1, precise=True)
k = self.top_k k = self.top_k
indices = mx.stop_gradient(mx.argpartition(-routing_weights, kth=k-1, axis=-1)[..., :k]) indices = mx.stop_gradient(
mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k]
)
scores = mx.take_along_axis(routing_weights, indices, axis=-1) scores = mx.take_along_axis(routing_weights, indices, axis=-1)
if self.norm_topk_prob: if self.norm_topk_prob:
scores = scores / scores.sum(axis=-1, keepdims=True) scores = scores / scores.sum(axis=-1, keepdims=True)
@ -129,7 +131,9 @@ class TransformerBlock(nn.Module):
self.self_attn = Attention(args) self.self_attn = Attention(args)
self.mlp = OlmoeSparseMoeBlock(args) self.mlp = OlmoeSparseMoeBlock(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__( def __call__(
self, self,