Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00
Block sparse MM MoEs (#782)
- Adds SwitchLinear
- Adds QuantizedSwitchLinear
committed by GitHub
parent 199df9e110
commit 9f671228cd
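The hunks below appear to come from llms/mlx_lm/models/qwen2_moe.py (judging by the Qwen2Moe* class names and the relative .switch_layers import). The change replaces the per-expert Python loops with a single SwitchGLU whose expert weights are stacked into one tensor and applied with a gathered, block-sparse matmul. As a reference-only sketch of what such a "switch" linear computes, not the switch_layers.py implementation, and with made-up sizes:

import mlx.core as mx

# Reference-only sketch: all expert weight matrices are stacked into one
# (num_experts, d_out, d_in) tensor, and each token is multiplied by the
# matrices of the experts it was routed to.
def switch_linear_ref(x, stacked_w, inds):
    # x: (T, d_in) tokens, stacked_w: (E, d_out, d_in), inds: (T, k) expert ids
    w = stacked_w[inds]          # gather per-token expert matrices: (T, k, d_out, d_in)
    y = w @ x[:, None, :, None]  # batched matrix-vector product: (T, k, d_out, 1)
    return y.squeeze(-1)         # (T, k, d_out)

T, E, k, d_in, d_out = 6, 4, 2, 8, 16
x = mx.random.normal((T, d_in))
stacked_w = mx.random.normal((E, d_out, d_in))
inds = mx.random.randint(0, E, (T, k))
print(switch_linear_ref(x, stacked_w, inds).shape)  # (6, 2, 16)

The actual SwitchLinear/QuantizedSwitchLinear layers avoid materializing the gathered weights per token; the sketch only pins down the semantics.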
@@ -1,11 +1,12 @@
 import math
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
 from .base import BaseModelArgs
+from .switch_layers import SwitchGLU
 
 
 @dataclass
@@ -92,7 +93,7 @@ class Attention(nn.Module):
         return self.o_proj(output)
 
 
-class Qwen2MoeMLP(nn.Module):
+class MLP(nn.Module):
     def __init__(self, dim, hidden_dim):
         super().__init__()
         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
@@ -113,57 +114,32 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.num_experts = num_experts = args.num_experts
         self.top_k = args.num_experts_per_tok
 
         # gating
         self.gate = nn.Linear(dim, num_experts, bias=False)
-        self.experts = [
-            Qwen2MoeMLP(dim, intermediate_size) for _ in range(self.num_experts)
-        ]
-        self.shared_expert = Qwen2MoeMLP(dim, shared_expert_intermediate_size)
+        self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
+
+        self.shared_expert = MLP(dim, shared_expert_intermediate_size)
         self.shared_expert_gate = nn.Linear(dim, 1, bias=False)
 
     def __call__(
         self,
         x: mx.array,
     ):
-        ne = self.top_k
-        B, L, D = x.shape
-        x = x.reshape(-1, D)
-
-        # router_logits: (batch * sequence_length, n_experts)
         gates = self.gate(x)
-        gates = mx.softmax(gates.astype(mx.float32), axis=-1)
+        gates = mx.softmax(gates, axis=-1, precise=True)
 
-        inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne])
-        scores = mx.take_along_axis(gates, inds, axis=-1)
+        k = self.top_k
+        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
+        scores = mx.take_along_axis(gates, inds, axis=-1).astype(x.dtype)
 
-        if self.training:
-            inds = np.array(inds)
-            y = mx.zeros((B * L, ne, D), x.dtype)
-            for e, expert in enumerate(self.experts):
-                idx1, idx2 = map(mx.array, np.where(inds == e))
-                if idx1.size == 0:
-                    continue
-                y[idx1, idx2] = expert(x[idx1])
-
-            y = (y * scores[:, :, None]).sum(axis=1)
-        else:
-            y = []
-            for xt, st, it in zip(x, scores, inds.tolist()):
-                yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
-                yt = (yt * st).sum(axis=-1)
-                y.append(yt)
-
-            y = mx.stack(y, axis=0)
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2)
 
         shared_expert_output = self.shared_expert(x)
         shared_expert_output = (
             mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output
         )
 
-        y += shared_expert_output
-
-        return y.reshape(B, L, -1)
+        return y + shared_expert_output
 
 
 class Qwen2MoeDecoderLayer(nn.Module):
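A note on the rewritten routing above: mx.argpartition(-gates, kth=k - 1) places the k largest gate probabilities (in arbitrary order) in the first k positions, so [..., :k] yields the top-k expert indices without a full sort, and precise=True accumulates the softmax in higher precision so the old float32 upcast is no longer needed. Below is a self-contained sketch of just that routing and combining arithmetic, with made-up shapes and random stand-ins for self.gate(x) and self.switch_mlp(x, inds):

import mlx.core as mx

# Standalone sketch of the routing math in the new __call__ (hypothetical
# sizes; `expert_out` stands in for what the fused switch MLP returns).
B, L, D, E, k = 2, 5, 8, 4, 2
x = mx.random.normal((B, L, D))

gates = mx.random.normal((B, L, E))                # stand-in for self.gate(x)
gates = mx.softmax(gates, axis=-1, precise=True)

# Top-k expert ids per token without a full sort.
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)  # (B, L, k)

expert_out = mx.random.normal((B, L, k, D))        # stand-in for self.switch_mlp(x, inds)
y = (expert_out * scores[..., None]).sum(axis=-2)  # weighted sum over the k experts
print(y.shape)  # (2, 5, 8)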
@@ -243,12 +219,19 @@ class Model(nn.Module):
         return self.lm_head(out)
 
     def sanitize(self, weights):
-        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
-            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
-        # Remove unused precomputed rotary freqs
-        return {
-            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
-        }
+        if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
+            return weights
+        for l in range(self.args.num_hidden_layers):
+            prefix = f"model.layers.{l}"
+            for n in ["up_proj", "down_proj", "gate_proj"]:
+                for k in ["weight", "scales", "biases"]:
+                    to_join = [
+                        weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
+                        for e in range(self.args.num_experts)
+                    ]
+                    if to_join:
+                        weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
+        return weights
 
     @property
     def layers(self):
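The new sanitize above converts checkpoints that store one up_proj/gate_proj/down_proj per expert into the stacked layout the fused switch_mlp expects (quantized checkpoints additionally carry per-projection scales and biases, handled by the same loop). A toy sketch of the transformation for a single projection, with made-up sizes:

import mlx.core as mx

# Each per-expert matrix model.layers.0.mlp.experts.{e}.up_proj.weight of shape
# (hidden, dim) is popped and stacked along a new leading expert axis, producing
# model.layers.0.mlp.switch_mlp.up_proj.weight of shape (num_experts, hidden, dim).
num_experts, dim, hidden = 4, 8, 16
weights = {
    f"model.layers.0.mlp.experts.{e}.up_proj.weight": mx.zeros((hidden, dim))
    for e in range(num_experts)
}

to_join = [
    weights.pop(f"model.layers.0.mlp.experts.{e}.up_proj.weight")
    for e in range(num_experts)
]
weights["model.layers.0.mlp.switch_mlp.up_proj.weight"] = mx.stack(to_join)
print(weights["model.layers.0.mlp.switch_mlp.up_proj.weight"].shape)  # (4, 16, 8)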