mirror of https://github.com/ml-explore/mlx-examples.git
adding SwitchGLU

commit 140285080d
parent 230abad426
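Note: the change swaps the per-expert Python loop in OlmoeSparseMoeBlock for a single SwitchGLU layer from switch_layers, which evaluates only the experts selected for each token. In the diff below the block is constructed as SwitchGLU(args.hidden_size, args.intermediate_size, self.num_experts, bias=args.mlp_bias) and called as switch_mlp(x, indices). As a rough mental model of that call contract, here is a deliberately naive, self-contained sketch; NaiveSwitchGLU and its per-token loop are illustrative assumptions, not the actual switch_layers implementation.

import mlx.core as mx
import mlx.nn as nn


class NaiveSwitchGLU(nn.Module):
    # Illustrative stand-in only: one gated MLP per expert, applied token by
    # token. The real SwitchGLU batches the expert weights instead, but the
    # input/output contract sketched here matches how the diff uses it.
    def __init__(self, dims, hidden_dims, num_experts, bias=False):
        super().__init__()
        self.experts = [
            {
                "gate": nn.Linear(dims, hidden_dims, bias=bias),
                "up": nn.Linear(dims, hidden_dims, bias=bias),
                "down": nn.Linear(hidden_dims, dims, bias=bias),
            }
            for _ in range(num_experts)
        ]

    def __call__(self, x, indices):
        # x: (N, dims), indices: (N, top_k) -> returns (N, top_k, dims)
        outputs = []
        for i in range(x.shape[0]):
            per_expert = []
            for j in range(indices.shape[1]):
                e = self.experts[indices[i, j].item()]
                h = nn.silu(e["gate"](x[i])) * e["up"](x[i])
                per_expert.append(e["down"](h))
            outputs.append(mx.stack(per_expert))
        return mx.stack(outputs)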
@@ -8,6 +8,7 @@ import mlx.nn as nn
 
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .rope_utils import initialize_rope
+from .switch_layers import SwitchGLU
 
 
 @dataclass
@@ -48,7 +49,7 @@ class Attention(nn.Module):
         self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
 
         self.scale = head_dim**-0.5
 
         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
@@ -90,17 +91,6 @@ class Attention(nn.Module):
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
 
 
-class MLP(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=args.mlp_bias)
-        self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.mlp_bias)
-        self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=args.mlp_bias)
-
-    def __call__(self, x) -> mx.array:
-        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
-
-
 class OlmoeSparseMoeBlock(nn.Module):
@@ -109,20 +99,41 @@ class OlmoeSparseMoeBlock(nn.Module):
         self.num_experts = args.num_experts
         self.top_k = args.num_experts_per_tok
         self.norm_topk_prob = args.norm_topk_prob
 
         self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False)
-        self.experts = [MLP(args) for _ in range(self.num_experts)]
+        self.switch_mlp = SwitchGLU(
+            args.hidden_size,
+            args.intermediate_size,
+            self.num_experts,
+            bias=args.mlp_bias
+        )
 
     def __call__(self, x: mx.array) -> mx.array:
         B, L, D = x.shape
-        x = x.reshape(-1, D)
-        router_logits = self.gate(x)
+        x_flat = x.reshape(-1, D)
+
+        # Compute routing probabilities
+        router_logits = self.gate(x_flat)
         routing_weights = mx.softmax(router_logits, axis=1, precise=True)
-        final_hidden_states = mx.zeros_like(x)
-        for expert_idx in range(self.num_experts):
-            expert_weights = routing_weights[:, expert_idx:expert_idx+1]
-            expert_output = self.experts[expert_idx](x)
-            final_hidden_states += expert_output * expert_weights
-        return final_hidden_states.reshape(B, L, D)
+
+        # Get top-k experts
+        top_k = self.top_k
+        indices = mx.stop_gradient(mx.argpartition(-routing_weights, kth=top_k-1, axis=-1)[..., :top_k])
+        scores = mx.take_along_axis(routing_weights, indices, axis=-1)
+
+        # Normalize probabilities (optional)
+        if self.norm_topk_prob:
+            scores = scores / scores.sum(axis=-1, keepdims=True)
+
+        # Reshape for switch_mlp
+        x_reshaped = x_flat.reshape(B*L, D)
+        indices_reshaped = indices.reshape(B*L, top_k)
+
+        # Apply experts and combine with routing weights
+        expert_outputs = self.switch_mlp(x_reshaped, indices_reshaped)
+        outputs = (expert_outputs * scores.reshape(B*L, top_k, 1)).sum(axis=1)
+
+        return outputs.reshape(B, L, D)
 
 
 class TransformerBlock(nn.Module):
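As a side note on the routing step added above, the following toy example with made-up logits shows how mx.softmax, mx.argpartition, and mx.take_along_axis select the top_k experts and their optionally renormalized scores. It mirrors the lines in the hunk but is not part of the commit.

import mlx.core as mx

top_k = 2
# Made-up router logits: 2 tokens, 4 experts
router_logits = mx.array([[0.1, 2.0, 0.3, 1.5],
                          [1.0, 0.2, 0.2, 0.1]])
routing_weights = mx.softmax(router_logits, axis=-1)

# Indices of the top_k largest routing weights per token (unordered)
indices = mx.stop_gradient(
    mx.argpartition(-routing_weights, kth=top_k - 1, axis=-1)[..., :top_k]
)
scores = mx.take_along_axis(routing_weights, indices, axis=-1)

# Optional renormalization so the selected scores sum to 1 per token
scores = scores / scores.sum(axis=-1, keepdims=True)
print(indices.shape, scores.shape)  # (2, 2) (2, 2)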
@@ -194,6 +205,21 @@ class Model(nn.Module):
         else:
             out = self.lm_head(out)
         return out
 
+    def sanitize(self, weights):
+        if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
+            return weights
+        for l in range(self.args.num_hidden_layers):
+            prefix = f"model.layers.{l}"
+            for n in ["up_proj", "down_proj", "gate_proj"]:
+                for k in ["weight", "scales", "biases"]:
+                    if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
+                        to_join = [
+                            weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
+                            for e in range(self.args.num_experts)
+                        ]
+                        weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
+        return weights
+
     @property
     def layers(self):
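The sanitize hook added above lets checkpoints saved with the old per-expert experts.{e} weights load into the new switch_mlp layout: for each projection it pops the per-expert matrices and stacks them along a new leading expert axis. A small standalone illustration of that remapping, using made-up sizes and a fake weight dict rather than a real checkpoint:

import mlx.core as mx

# Hypothetical sizes, for illustration only
num_experts, hidden_size, intermediate_size = 4, 8, 16

# Fake per-expert weights in the old "experts.{e}" layout; nn.Linear stores
# weight as (output_dims, input_dims), so up_proj is (intermediate, hidden)
weights = {
    f"model.layers.0.mlp.experts.{e}.up_proj.weight": mx.zeros(
        (intermediate_size, hidden_size)
    )
    for e in range(num_experts)
}

# The same stacking sanitize performs: gather the per-expert matrices and
# stack them along a new leading expert axis under switch_mlp
to_join = [
    weights.pop(f"model.layers.0.mlp.experts.{e}.up_proj.weight")
    for e in range(num_experts)
]
weights["model.layers.0.mlp.switch_mlp.up_proj.weight"] = mx.stack(to_join)

print(weights["model.layers.0.mlp.switch_mlp.up_proj.weight"].shape)
# (4, 16, 8): (num_experts, intermediate_size, hidden_size)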