Block sparse MM MoEs (#782)

- Adds SwitchLinear
- Adds QuantizedSwitchLinear
Angelos Katharopoulos authored on 2024-05-21 15:58:08 -07:00; committed by GitHub
parent 199df9e110
commit 9f671228cd
8 changed files with 365 additions and 143 deletions
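The two layers named in the commit message implement a linear layer whose weight matrix is selected per token by an expert index, so a whole top-k MoE dispatch becomes one batched block-sparse matmul instead of a Python loop over experts. As a rough illustration only, the sketch below writes the same semantics with a plain gather and matmul; the shapes and names are stand-ins, and the real SwitchLinear / QuantizedSwitchLinear in switch_layers.py use MLX's gathered matmul kernels rather than materializing weight[indices].

```python
import mlx.core as mx

# Stand-in sizes: 4 experts, top-2 routing, (batch, seq) = (3, 5).
num_experts, d_in, d_out, k = 4, 8, 16, 2

weight = mx.random.normal((num_experts, d_out, d_in))    # one matrix per expert
x = mx.random.normal((3, 5, d_in))                       # token activations
indices = mx.random.randint(0, num_experts, (3, 5, k))   # chosen expert ids per token

# Reference computation: gather each token's expert matrices, then matmul.
w = weight[indices]                                      # (3, 5, k, d_out, d_in)
y = (x[..., None, None, :] @ w.swapaxes(-1, -2)).squeeze(-2)
print(y.shape)                                           # (3, 5, 2, 16): one output per selected expert
```

QuantizedSwitchLinear presumably does the same with quantized expert weights (weight plus scales and biases), which is why the sanitize change below stacks scales and biases as well.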


@@ -1,11 +1,12 @@
 import math
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np

 from .base import BaseModelArgs
+from .switch_layers import SwitchGLU

 @dataclass
@@ -92,7 +93,7 @@ class Attention(nn.Module):
         return self.o_proj(output)


-class Qwen2MoeMLP(nn.Module):
+class MLP(nn.Module):
     def __init__(self, dim, hidden_dim):
         super().__init__()
         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
@@ -113,57 +114,32 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.num_experts = num_experts = args.num_experts
         self.top_k = args.num_experts_per_tok

         # gating
         self.gate = nn.Linear(dim, num_experts, bias=False)
-        self.experts = [
-            Qwen2MoeMLP(dim, intermediate_size) for _ in range(self.num_experts)
-        ]
-        self.shared_expert = Qwen2MoeMLP(dim, shared_expert_intermediate_size)
+        self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
+        self.shared_expert = MLP(dim, shared_expert_intermediate_size)
         self.shared_expert_gate = nn.Linear(dim, 1, bias=False)

     def __call__(
         self,
         x: mx.array,
     ):
-        ne = self.top_k
-        B, L, D = x.shape
-        x = x.reshape(-1, D)
-
-        # router_logits: (batch * sequence_length, n_experts)
         gates = self.gate(x)
-        gates = mx.softmax(gates.astype(mx.float32), axis=-1)
+        gates = mx.softmax(gates, axis=-1, precise=True)

-        inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne])
+        k = self.top_k
+        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])

-        scores = mx.take_along_axis(gates, inds, axis=-1)
+        scores = mx.take_along_axis(gates, inds, axis=-1).astype(x.dtype)

-        if self.training:
-            inds = np.array(inds)
-            y = mx.zeros((B * L, ne, D), x.dtype)
-            for e, expert in enumerate(self.experts):
-                idx1, idx2 = map(mx.array, np.where(inds == e))
-                if idx1.size == 0:
-                    continue
-                y[idx1, idx2] = expert(x[idx1])
-            y = (y * scores[:, :, None]).sum(axis=1)
-        else:
-            y = []
-            for xt, st, it in zip(x, scores, inds.tolist()):
-                yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
-                yt = (yt * st).sum(axis=-1)
-                y.append(yt)
-            y = mx.stack(y, axis=0)
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2)

         shared_expert_output = self.shared_expert(x)
         shared_expert_output = (
             mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output
         )
-        y += shared_expert_output

-        return y.reshape(B, L, -1)
+        return y + shared_expert_output


 class Qwen2MoeDecoderLayer(nn.Module):
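The rewritten routing above can be exercised on its own. The sketch below repeats the same calls on random stand-in tensors: gates plays the role of self.gate(x), and expert_out stands in for switch_mlp(x, inds), which returns one activation per selected expert; the weighted sum over the expert axis then mixes them.

```python
import mlx.core as mx

B, L, E, k, D = 2, 3, 8, 2, 4                             # stand-in sizes
gates = mx.softmax(mx.random.normal((B, L, E)), axis=-1, precise=True)

# argpartition with kth=k-1 moves the k largest gates (the k smallest of
# -gates) into the first k slots; their relative order is unspecified, which
# is fine because each expert output is weighted by its own gate below.
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)         # (B, L, k)

expert_out = mx.random.normal((B, L, k, D))               # stand-in for switch_mlp(x, inds)
y = (expert_out * scores[..., None]).sum(axis=-2)         # (B, L, D)
print(y.shape)
```

Note how the reshape to (B * L, D) disappears: the gather-based switch_mlp handles arbitrary leading dimensions, so the old separate training and inference branches over self.experts are no longer needed.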
@@ -243,12 +219,19 @@ class Model(nn.Module):
         return self.lm_head(out)

     def sanitize(self, weights):
         if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
             weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
-        # Remove unused precomputed rotary freqs
-        return {
-            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
-        }
+        if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
+            return weights
+        for l in range(self.args.num_hidden_layers):
+            prefix = f"model.layers.{l}"
+            for n in ["up_proj", "down_proj", "gate_proj"]:
+                # Quantized checkpoints also carry per-expert scales and biases;
+                # only stack keys that are actually present.
+                for k in ["weight", "scales", "biases"]:
+                    if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
+                        to_join = [
+                            weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
+                            for e in range(self.args.num_experts)
+                        ]
+                        weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
+        return weights

     @property
     def layers(self):
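The new sanitize converts checkpoints that store one small matrix per expert into the stacked (num_experts, out_dims, in_dims) layout the switch layers index by expert id. A toy version for a single projection, with shortened hypothetical key names:

```python
import mlx.core as mx

num_experts, d_in, d_hidden = 4, 8, 16

# Checkpoint layout before: one up_proj matrix per expert.
weights = {
    f"mlp.experts.{e}.up_proj.weight": mx.zeros((d_hidden, d_in))
    for e in range(num_experts)
}

# After: a single stacked tensor under the switch_mlp key.
stacked = mx.stack(
    [weights.pop(f"mlp.experts.{e}.up_proj.weight") for e in range(num_experts)]
)
weights["mlp.switch_mlp.up_proj.weight"] = stacked
print(stacked.shape)  # (4, 16, 8)
```

For quantized checkpoints the same stacking applies to the per-expert scales and biases, which is why sanitize loops over ["weight", "scales", "biases"].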