diff --git a/llms/mlx_lm/models/olmoe.py b/llms/mlx_lm/models/olmoe.py
index 1297ce83..a2b2b44e 100644
--- a/llms/mlx_lm/models/olmoe.py
+++ b/llms/mlx_lm/models/olmoe.py
@@ -8,6 +8,7 @@ import mlx.nn as nn
 
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .rope_utils import initialize_rope
+from .switch_layers import SwitchGLU
 
 
 @dataclass
@@ -48,7 +49,7 @@ class Attention(nn.Module):
 
         self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
         self.scale = head_dim**-0.5
-
+
         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
@@ -90,17 +91,6 @@ class Attention(nn.Module):
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
-
-
-class MLP(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=args.mlp_bias)
-        self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.mlp_bias)
-        self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=args.mlp_bias)
-
-    def __call__(self, x) -> mx.array:
-        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
 
 
 class OlmoeSparseMoeBlock(nn.Module):
@@ -109,20 +99,41 @@ class OlmoeSparseMoeBlock(nn.Module):
         self.num_experts = args.num_experts
         self.top_k = args.num_experts_per_tok
         self.norm_topk_prob = args.norm_topk_prob
+
         self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False)
-        self.experts = [MLP(args) for _ in range(self.num_experts)]
+        self.switch_mlp = SwitchGLU(
+            args.hidden_size,
+            args.intermediate_size,
+            self.num_experts,
+            bias=args.mlp_bias
+        )
 
     def __call__(self, x: mx.array) -> mx.array:
         B, L, D = x.shape
-        x = x.reshape(-1, D)
-        router_logits = self.gate(x)
+        x_flat = x.reshape(-1, D)
+
+        # Compute routing probabilities
+        router_logits = self.gate(x_flat)
         routing_weights = mx.softmax(router_logits, axis=1, precise=True)
-        final_hidden_states = mx.zeros_like(x)
-        for expert_idx in range(self.num_experts):
-            expert_weights = routing_weights[:, expert_idx:expert_idx+1]
-            expert_output = self.experts[expert_idx](x)
-            final_hidden_states += expert_output * expert_weights
-        return final_hidden_states.reshape(B, L, D)
+
+        # Get top-k experts
+        top_k = self.top_k
+        indices = mx.stop_gradient(mx.argpartition(-routing_weights, kth=top_k-1, axis=-1)[..., :top_k])
+        scores = mx.take_along_axis(routing_weights, indices, axis=-1)
+
+        # Normalize probabilities (optional)
+        if self.norm_topk_prob:
+            scores = scores / scores.sum(axis=-1, keepdims=True)
+
+        # Reshape for switch_mlp
+        x_reshaped = x_flat.reshape(B*L, D)
+        indices_reshaped = indices.reshape(B*L, top_k)
+
+        # Apply experts and combine with routing weights
+        expert_outputs = self.switch_mlp(x_reshaped, indices_reshaped)
+        outputs = (expert_outputs * scores.reshape(B*L, top_k, 1)).sum(axis=1)
+
+        return outputs.reshape(B, L, D)
 
 
 class TransformerBlock(nn.Module):
@@ -194,6 +205,21 @@ class Model(nn.Module):
         else:
             out = self.lm_head(out)
         return out
+
+    def sanitize(self, weights):
+        if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
+            return weights
+        for l in range(self.args.num_hidden_layers):
+            prefix = f"model.layers.{l}"
+            for n in ["up_proj", "down_proj", "gate_proj"]:
+                for k in ["weight", "scales", "biases"]:
+                    if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
+                        to_join = [
+                            weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
+                            for e in range(self.args.num_experts)
+                        ]
+                        weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
+        return weights
 
     @property
     def layers(self):
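
Below is a minimal standalone sketch (not part of the patch) of the top-k routing math used in OlmoeSparseMoeBlock above. The SwitchGLU expert computation is stood in for by a random tensor, and the toy sizes and variable names are illustrative assumptions only.

# Illustrative sketch only: mirrors the argpartition / take_along_axis routing
# from the patch; expert outputs are faked rather than computed by SwitchGLU.
import mlx.core as mx

T, D, num_experts, top_k = 6, 4, 8, 2          # toy sizes (assumptions)

x = mx.random.normal((T, D))                   # flattened tokens, like x_flat
gate_w = mx.random.normal((num_experts, D))    # stand-in for the gate's weight

router_logits = x @ gate_w.T                   # (T, num_experts)
routing_weights = mx.softmax(router_logits, axis=-1)

# Select the top_k experts per token without sorting the full distribution.
indices = mx.argpartition(-routing_weights, kth=top_k - 1, axis=-1)[..., :top_k]
scores = mx.take_along_axis(routing_weights, indices, axis=-1)
scores = scores / scores.sum(axis=-1, keepdims=True)  # norm_topk_prob=True branch

# Stand-in for switch_mlp(x, indices): one output vector per (token, expert) pair.
expert_outputs = mx.random.normal((T, top_k, D))

# Weight each expert's output by its routing score and sum over the k experts.
out = (expert_outputs * scores.reshape(T, top_k, 1)).sum(axis=1)
print(out.shape)  # (6, 4)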