Block sparse MM MoEs (#782)

- Adds SwitchLinear
- Adds QuantizedSwitchLinear
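Below is a minimal reference sketch (not this PR's implementation) of the computation a SwitchLinear-style layer performs: every token is multiplied only by the weight matrices of the experts selected for it. The function name, shapes, and the gather-then-matmul formulation are illustrative assumptions; per the commit title, the actual layers presumably dispatch this to block-sparse matmul kernels, and QuantizedSwitchLinear does the same with quantized expert weights.

import mlx.core as mx


def switch_linear_reference(x, weight, bias, inds):
    # Illustrative reference only; names and shapes are assumptions.
    # x:      (T, Din)        tokens
    # weight: (E, Dout, Din)  one weight matrix per expert
    # bias:   (E, Dout)       one bias vector per expert
    # inds:   (T, K)          indices of the k selected experts per token
    w = mx.take(weight, inds, axis=0)  # (T, K, Dout, Din): gather selected expert weights
    b = mx.take(bias, inds, axis=0)    # (T, K, Dout)
    y = w @ x[:, None, :, None]        # broadcast batched matmul -> (T, K, Dout, 1)
    return y[..., 0] + b               # (T, K, Dout)


# Toy shapes: 6 tokens, 4 experts, 2 experts per token.
T, Din, Dout, E, K = 6, 8, 16, 4, 2
x = mx.random.normal((T, Din))
weight = mx.random.normal((E, Dout, Din))
bias = mx.zeros((E, Dout))
inds = mx.random.randint(0, E, (T, K))
print(switch_linear_reference(x, weight, bias, inds).shape)  # (6, 2, 16)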
Angelos Katharopoulos
2024-05-21 15:58:08 -07:00
committed by GitHub
parent 199df9e110
commit 9f671228cd
8 changed files with 365 additions and 143 deletions


@@ -5,7 +5,8 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
+from .switch_layers import SwitchMLP

 @dataclass
@@ -75,17 +76,6 @@ class RoPEAttention(nn.Module):
         return self.out_proj(output)


-class MLP(nn.Module):
-    def __init__(self, dim, hidden_dim):
-        super().__init__()
-        self.fc1 = nn.Linear(dim, hidden_dim)
-        self.fc2 = nn.Linear(hidden_dim, dim)
-        self.act = nn.GELU(approx="precise")
-
-    def __call__(self, x) -> mx.array:
-        return self.fc2(self.act(self.fc1(x)))
-
-
 class MOE(nn.Module):
     def __init__(self, args: ModelArgs, dim: int, hidden_dim: int):
         super().__init__()
@@ -93,40 +83,23 @@ class MOE(nn.Module):
         self.hidden_dim = hidden_dim
         self.num_experts = args.num_local_experts
         self.num_experts_per_tok = args.num_experts_per_tok
-        self.mlp = [MLP(self.dim, self.hidden_dim) for _ in range(self.num_experts)]
+        self.switch_mlp = SwitchMLP(
+            self.dim, self.hidden_dim, self.num_experts, bias=True
+        )
         self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)

     def __call__(self, x: mx.array) -> mx.array:
-        ne = self.num_experts_per_tok
-        orig_shape = x.shape
-        x = x.reshape(-1, x.shape[-1])
-
         gates = self.gate(x)
-        inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1))[:, :ne]
-        scores = mx.softmax(
-            mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
-            axis=-1,
-        ).astype(gates.dtype)
-
-        if self.training:
-            ys = []
-            y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
-            for e, expert in enumerate(self.mlp):
-                idx1, idx2 = map(mx.array, np.where(inds == e))
-                if idx1.size == 0:
-                    continue
-                y[idx1, idx2] = expert(x[idx1])
+        k = self.num_experts_per_tok
+        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1))[..., :k]
+        scores = mx.take_along_axis(gates, inds, axis=-1)
+        scores = mx.softmax(scores, axis=-1, precise=True)

-            y = (y * scores[..., None]).sum(axis=1)
-        else:
-            y = []
-            for xt, st, it in zip(x, scores, inds.tolist()):
-                yt = mx.stack([self.mlp[e](xt) for e in it], axis=-1)
-                yt = (yt * st).sum(axis=-1)
-                y.append(yt[None, :])
-            y = mx.concatenate(y)
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2)

-        return y.reshape(orig_shape)
+        return y


 class ParallelBlock(nn.Module):
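For reference, a small self-contained check (illustrative values only) of the top-k routing used in the rewritten MOE.__call__ above: argpartition on the negated gate logits picks the k highest-scoring experts per token, and the softmax is taken over just those k logits to produce the mixture weights.

import mlx.core as mx

gates = mx.array([[0.1, 2.0, -1.0, 0.5],
                  [1.5, 0.2, 0.9, -0.3]])  # (tokens, num_experts)
k = 2
# The first k entries of the argpartition are the k largest gates (order unspecified).
inds = mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
print(inds.tolist())    # e.g. [[1, 3], [0, 2]]
print(scores.tolist())  # per-token weights over the selected experts; each row sums to 1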
@@ -202,6 +175,21 @@ class Model(nn.Module):
         y = self.transformer(x, mask, cache)
         return self.lm_head(y)

+    def sanitize(self, weights):
+        if "transformer.h.0.moe.mlp.0.fc1.weight" not in weights:
+            return weights
+        for l in range(self.args.num_layers):
+            prefix = f"transformer.h.{l}"
+            for n in ["fc1", "fc2"]:
+                for k in ["weight", "scales", "biases", "bias"]:
+                    # Skip keys that are absent (e.g. scales/biases in unquantized checkpoints).
+                    to_join = [
+                        weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
+                        for e in range(self.args.num_local_experts)
+                        if f"{prefix}.moe.mlp.{e}.{n}.{k}" in weights
+                    ]
+                    if to_join:
+                        weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
+        return weights
+
     @property
     def layers(self):
         return self.transformer.h
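As a toy illustration (layer name and sizes invented) of the remapping sanitize performs above: the per-expert fc1/fc2 tensors of an old checkpoint are stacked along a new leading expert axis, which is the layout the switch_mlp weights expect.

import mlx.core as mx

num_experts, dim, hidden_dim = 4, 8, 32
old = {
    f"transformer.h.0.moe.mlp.{e}.fc1.weight": mx.zeros((hidden_dim, dim))
    for e in range(num_experts)
}
stacked = mx.stack(
    [old.pop(f"transformer.h.0.moe.mlp.{e}.fc1.weight") for e in range(num_experts)]
)
print(stacked.shape)  # (4, 32, 8): experts become the leading axis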