commit a1ff1bf72a
parent e4c56625f0
Author: Goekdeniz-Guelmez
Date:   2025-03-05 09:30:09 +01:00


@@ -91,7 +91,7 @@ class Attention(nn.Module):
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)


 class OlmoeSparseMoeBlock(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -99,13 +99,13 @@ class OlmoeSparseMoeBlock(nn.Module):
         self.num_experts = args.num_experts
         self.top_k = args.num_experts_per_tok
         self.norm_topk_prob = args.norm_topk_prob
         self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False)
         self.switch_mlp = SwitchGLU(
-            args.hidden_size,
-            args.intermediate_size,
+            args.hidden_size,
+            args.intermediate_size,
             self.num_experts,
-            bias=args.mlp_bias
+            bias=args.mlp_bias,
         )

     def __call__(self, x: mx.array) -> mx.array:
@@ -114,7 +114,9 @@ class OlmoeSparseMoeBlock(nn.Module):
         router_logits = self.gate(x_flat)
         routing_weights = mx.softmax(router_logits, axis=1, precise=True)
         k = self.top_k
-        indices = mx.stop_gradient(mx.argpartition(-routing_weights, kth=k-1, axis=-1)[..., :k])
+        indices = mx.stop_gradient(
+            mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k]
+        )
         scores = mx.take_along_axis(routing_weights, indices, axis=-1)
         if self.norm_topk_prob:
             scores = scores / scores.sum(axis=-1, keepdims=True)
@@ -129,7 +131,9 @@ class TransformerBlock(nn.Module):
         self.self_attn = Attention(args)
         self.mlp = OlmoeSparseMoeBlock(args)
         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )

     def __call__(
         self,
@@ -192,7 +196,7 @@ class Model(nn.Module):
         else:
             out = self.lm_head(out)
         return out

     def sanitize(self, weights):
         if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
             return weights
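
Note on the routing hunk above: the reformatted line is the top-k expert selection. A minimal standalone sketch of that pattern in MLX, using toy shapes and random gate weights that are illustrative only (not from the model):

import mlx.core as mx

# Toy router: 4 tokens, hidden size 16, 8 experts, top-2 routing.
num_experts, top_k = 8, 2
x_flat = mx.random.normal((4, 16))
gate_w = mx.random.normal((16, num_experts))

router_logits = x_flat @ gate_w
routing_weights = mx.softmax(router_logits, axis=-1, precise=True)

# argpartition on the negated weights places the k largest entries first;
# stop_gradient keeps the non-differentiable indices out of autodiff.
k = top_k
indices = mx.stop_gradient(
    mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k]
)
scores = mx.take_along_axis(routing_weights, indices, axis=-1)

# With norm_topk_prob set, the kept scores are renormalized to sum to 1.
scores = scores / scores.sum(axis=-1, keepdims=True)
print(indices.shape, scores.shape)  # (4, 2) (4, 2)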
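The sanitize hunk shows only the early return for checkpoints already in the fused format; the conversion it guards is not visible in this diff. A hypothetical sketch of what such a conversion typically looks like for MoE models with a SwitchGLU layer (num_layers, num_experts, and the switch_mlp key layout are assumptions here, not confirmed by the diff):

import mlx.core as mx

def sanitize(weights, num_layers, num_experts):
    # Early return matches the diff: nothing to do if the per-expert
    # HF-style keys are absent (weights are already fused).
    if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
        return weights
    # Hypothetical stacking: gather each projection across experts into
    # one (num_experts, ...) tensor for the fused SwitchGLU layer.
    for l in range(num_layers):
        prefix = f"model.layers.{l}.mlp"
        for proj in ("gate_proj", "up_proj", "down_proj"):
            weights[f"{prefix}.switch_mlp.{proj}.weight"] = mx.stack(
                [
                    weights.pop(f"{prefix}.experts.{e}.{proj}.weight")
                    for e in range(num_experts)
                ]
            )
    return weights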