diff --git a/llms/mlx_lm/models/olmoe.py b/llms/mlx_lm/models/olmoe.py
index b16bec8d..b9c0fc69 100644
--- a/llms/mlx_lm/models/olmoe.py
+++ b/llms/mlx_lm/models/olmoe.py
@@ -91,7 +91,7 @@ class Attention(nn.Module):
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
-    
+
 
 class OlmoeSparseMoeBlock(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -99,13 +99,13 @@ class OlmoeSparseMoeBlock(nn.Module):
         self.num_experts = args.num_experts
         self.top_k = args.num_experts_per_tok
         self.norm_topk_prob = args.norm_topk_prob
-        
+
         self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False)
         self.switch_mlp = SwitchGLU(
-            args.hidden_size, 
-            args.intermediate_size, 
+            args.hidden_size,
+            args.intermediate_size,
             self.num_experts,
-            bias=args.mlp_bias
+            bias=args.mlp_bias,
         )
 
     def __call__(self, x: mx.array) -> mx.array:
@@ -114,7 +114,9 @@ class OlmoeSparseMoeBlock(nn.Module):
         router_logits = self.gate(x_flat)
         routing_weights = mx.softmax(router_logits, axis=1, precise=True)
         k = self.top_k
-        indices = mx.stop_gradient(mx.argpartition(-routing_weights, kth=k-1, axis=-1)[..., :k])
+        indices = mx.stop_gradient(
+            mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k]
+        )
         scores = mx.take_along_axis(routing_weights, indices, axis=-1)
         if self.norm_topk_prob:
             scores = scores / scores.sum(axis=-1, keepdims=True)
@@ -129,7 +131,9 @@ class TransformerBlock(nn.Module):
         self.self_attn = Attention(args)
         self.mlp = OlmoeSparseMoeBlock(args)
         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
 
     def __call__(
         self,
@@ -192,7 +196,7 @@ class Model(nn.Module):
         else:
             out = self.lm_head(out)
         return out
-    
+
     def sanitize(self, weights):
         if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
             return weights
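
As context for review (not part of the patch): a minimal standalone sketch of the top-k routing that the reformatted indices block in OlmoeSparseMoeBlock performs, reusing the same mx calls as the diff. The token count, hidden size, random gate weights, and the norm_topk_prob=True branch below are illustrative assumptions, not values from the model.

import mlx.core as mx

num_experts, top_k = 4, 2
x_flat = mx.random.normal((3, 8))            # toy batch: 3 tokens, hidden_size 8 (assumed)
gate_w = mx.random.normal((num_experts, 8))  # stands in for the weight of self.gate

router_logits = x_flat @ gate_w.T                                  # plays the role of self.gate(x_flat)
routing_weights = mx.softmax(router_logits, axis=1, precise=True)  # shape (3, num_experts)

# argpartition on the negated weights places the top_k largest routing weights in the
# first top_k positions (in no particular order); [..., :top_k] keeps their expert
# indices, and stop_gradient keeps the selection out of the backward pass.
indices = mx.stop_gradient(
    mx.argpartition(-routing_weights, kth=top_k - 1, axis=-1)[..., :top_k]
)
scores = mx.take_along_axis(routing_weights, indices, axis=-1)
scores = scores / scores.sum(axis=-1, keepdims=True)  # the norm_topk_prob=True branch

print(indices.shape, scores.shape)  # (3, 2) (3, 2)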