commit a1ff1bf72a
parent e4c56625f0
Author: Goekdeniz-Guelmez
Date:   2025-03-05 09:30:09 +01:00


@@ -91,7 +91,7 @@ class Attention(nn.Module):
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
 
 
 class OlmoeSparseMoeBlock(nn.Module):
     def __init__(self, args: ModelArgs):
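
The reshuffle in the first hunk is the standard multi-head merge: the attention result arrives as (B, num_heads, L, head_dim) and o_proj expects (B, L, num_heads * head_dim). A minimal shape check with toy sizes (not part of the commit):

import mlx.core as mx

B, n_heads, L, head_dim = 2, 4, 8, 16
output = mx.zeros((B, n_heads, L, head_dim))  # attention result, one slice per head
merged = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
assert merged.shape == (B, L, n_heads * head_dim)  # ready for o_proj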
@@ -99,13 +99,13 @@ class OlmoeSparseMoeBlock(nn.Module):
         self.num_experts = args.num_experts
         self.top_k = args.num_experts_per_tok
         self.norm_topk_prob = args.norm_topk_prob
 
         self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False)
         self.switch_mlp = SwitchGLU(
             args.hidden_size,
             args.intermediate_size,
             self.num_experts,
-            bias=args.mlp_bias
+            bias=args.mlp_bias,
         )
 
     def __call__(self, x: mx.array) -> mx.array:
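
For context on the second hunk (only a trailing comma changed): SwitchGLU bundles num_experts gated MLPs behind one module so the selected experts can be evaluated with batched gathers instead of a Python loop. A loop-based stand-in to show what the constructor arguments mean; GLU and NaiveSwitchGLU are illustrative names, not the mlx_lm implementation, and the (x, indices) call convention is an assumption based on mlx-lm's other MoE models:

import mlx.core as mx
import mlx.nn as nn

class GLU(nn.Module):
    # One expert: the usual gated MLP, silu(gate(x)) * up(x), then a down-projection.
    def __init__(self, dim, hidden_dim, bias=False):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)

    def __call__(self, x):
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))

class NaiveSwitchGLU(nn.Module):
    def __init__(self, hidden_size, intermediate_size, num_experts, bias=False):
        super().__init__()
        self.experts = [
            GLU(hidden_size, intermediate_size, bias=bias)
            for _ in range(num_experts)
        ]

    def __call__(self, x, indices):
        # x: (tokens, hidden_size); indices: (tokens, k) expert ids per token.
        rows = []
        for t in range(indices.shape[0]):
            rows.append(
                mx.stack([self.experts[int(e)](x[t]) for e in indices[t].tolist()])
            )
        return mx.stack(rows)  # (tokens, k, hidden_size)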
@@ -114,7 +114,9 @@ class OlmoeSparseMoeBlock(nn.Module):
         router_logits = self.gate(x_flat)
         routing_weights = mx.softmax(router_logits, axis=1, precise=True)
         k = self.top_k
-        indices = mx.stop_gradient(mx.argpartition(-routing_weights, kth=k-1, axis=-1)[..., :k])
+        indices = mx.stop_gradient(
+            mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k]
+        )
         scores = mx.take_along_axis(routing_weights, indices, axis=-1)
         if self.norm_topk_prob:
             scores = scores / scores.sum(axis=-1, keepdims=True)
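
The third hunk only re-wraps the top-k selection, but the selection itself is worth unpacking: argpartition on the negated weights puts the k largest entries first (kth=k - 1 guarantees the first k slots hold the top-k, in no particular order), and stop_gradient keeps gradients from flowing through the discrete expert choice. A self-contained re-run of the routing math with toy sizes (not from the commit):

import mlx.core as mx

tokens, num_experts, k = 5, 8, 2
router_logits = mx.random.normal((tokens, num_experts))   # stands in for self.gate(x_flat)
routing_weights = mx.softmax(router_logits, axis=-1, precise=True)
indices = mx.stop_gradient(
    mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k]
)
scores = mx.take_along_axis(routing_weights, indices, axis=-1)
scores = scores / scores.sum(axis=-1, keepdims=True)      # the norm_topk_prob branch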
@@ -129,7 +131,9 @@ class TransformerBlock(nn.Module):
         self.self_attn = Attention(args)
         self.mlp = OlmoeSparseMoeBlock(args)
         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
 
     def __call__(
         self,
@@ -192,7 +196,7 @@ class Model(nn.Module):
         else:
             out = self.lm_head(out)
         return out
 
     def sanitize(self, weights):
         if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
             return weights
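
The sanitize guard returns early when the checkpoint already stores stacked expert weights. The rest of the method falls outside this hunk; in mlx-lm MoE ports it conventionally stacks the per-expert Hugging Face tensors into the single weights SwitchGLU loads. A sketch of that pattern, not the commit's actual body (num_layers and num_experts would come from the config; the key names are inferred from the guard string above):

for l in range(num_layers):
    prefix = f"model.layers.{l}.mlp"
    for name in ["gate_proj", "up_proj", "down_proj"]:
        # Pull expert e's matrix for this projection, then stack along a new axis.
        to_join = [
            weights.pop(f"{prefix}.experts.{e}.{name}.weight")
            for e in range(num_experts)
        ]
        weights[f"{prefix}.switch_mlp.{name}.weight"] = mx.stack(to_join)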