Block sparse MM MoEs (#782)

- Adds SwitchLinear
- Adds QuantizedSwitchLinear
This commit is contained in:
Angelos Katharopoulos
2024-05-21 15:58:08 -07:00
committed by GitHub
parent 199df9e110
commit 9f671228cd
8 changed files with 365 additions and 143 deletions

View File

@@ -1,11 +1,12 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
@@ -91,24 +92,6 @@ class MixtralAttention(nn.Module):
return self.o_proj(output)
class MixtralBLockSparseTop2MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.ffn_dim = args.intermediate_size
self.hidden_dim = args.hidden_size
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.act_fn = nn.silu
def __call__(self, x: mx.array) -> mx.array:
current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
class MixtralSparseMoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@@ -120,43 +103,20 @@ class MixtralSparseMoeBlock(nn.Module):
# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.experts = [
MixtralBLockSparseTop2MLP(args=args) for _ in range(self.num_experts)
]
self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)
def __call__(self, x: mx.array) -> mx.array:
ne = self.num_experts_per_tok
orig_shape = x.shape
x = x.reshape(-1, x.shape[-1])
gates = self.gate(x)
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
k = self.num_experts_per_tok
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = mx.softmax(scores, axis=-1, precise=True)
scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
axis=-1,
).astype(gates.dtype)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
if self.training:
inds = np.array(inds)
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
for e, expert in enumerate(self.experts):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0:
continue
y[idx1, idx2] = expert(x[idx1])
y = (y * scores[:, :, None]).sum(axis=1)
else:
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt[None, :])
y = mx.concatenate(y)
return y.reshape(orig_shape)
return y
class MixtralDecoderLayer(nn.Module):
@@ -235,6 +195,23 @@ class Model(nn.Module):
out = self.model(inputs, cache)
return self.lm_head(out)
def sanitize(self, weights):
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
to_join = [
weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}")
for e in range(self.args.num_local_experts)
]
if to_join:
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
mx.stack(to_join)
)
return weights
@property
def layers(self):
return self.model.layers

View File

@@ -5,7 +5,8 @@ from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .switch_layers import SwitchMLP
@dataclass
@@ -75,17 +76,6 @@ class RoPEAttention(nn.Module):
return self.out_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)
self.act = nn.GELU(approx="precise")
def __call__(self, x) -> mx.array:
return self.fc2(self.act(self.fc1(x)))
class MOE(nn.Module):
def __init__(self, args: ModelArgs, dim: int, hidden_dim: int):
super().__init__()
@@ -93,40 +83,23 @@ class MOE(nn.Module):
self.hidden_dim = hidden_dim
self.num_experts = args.num_local_experts
self.num_experts_per_tok = args.num_experts_per_tok
self.mlp = [MLP(self.dim, self.hidden_dim) for _ in range(self.num_experts)]
self.switch_mlp = SwitchMLP(
self.dim, self.hidden_dim, self.num_experts, bias=True
)
self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)
def __call__(self, x: mx.array) -> mx.array:
ne = self.num_experts_per_tok
orig_shape = x.shape
x = x.reshape(-1, x.shape[-1])
gates = self.gate(x)
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1))[:, :ne]
scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
axis=-1,
).astype(gates.dtype)
if self.training:
ys = []
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
for e, expert in enumerate(self.mlp):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0:
continue
y[idx1, idx2] = expert(x[idx1])
k = self.num_experts_per_tok
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1))[..., :k]
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = mx.softmax(scores, axis=-1, precise=True)
y = (y * scores[..., None]).sum(axis=1)
else:
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.stack([self.mlp[e](xt) for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt[None, :])
y = mx.concatenate(y)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
return y.reshape(orig_shape)
return y
class ParallelBlock(nn.Module):
@@ -202,6 +175,21 @@ class Model(nn.Module):
y = self.transformer(x, mask, cache)
return self.lm_head(y)
def sanitize(self, weights):
if "transformer.h.0.moe.mlp.0.fc1.weight" not in weights:
return weights
for l in range(self.args.num_layers):
prefix = f"transformer.h.{l}"
for n in ["fc1", "fc2"]:
for k in ["weight", "scales", "biases", "bias"]:
to_join = [
weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
for e in range(self.args.num_local_experts)
]
if to_join:
weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.transformer.h

View File

@@ -1,11 +1,12 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
@@ -92,7 +93,7 @@ class Attention(nn.Module):
return self.o_proj(output)
class Qwen2MoeMLP(nn.Module):
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
@@ -113,57 +114,32 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
self.num_experts = num_experts = args.num_experts
self.top_k = args.num_experts_per_tok
# gating
self.gate = nn.Linear(dim, num_experts, bias=False)
self.experts = [
Qwen2MoeMLP(dim, intermediate_size) for _ in range(self.num_experts)
]
self.shared_expert = Qwen2MoeMLP(dim, shared_expert_intermediate_size)
self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
self.shared_expert = MLP(dim, shared_expert_intermediate_size)
self.shared_expert_gate = nn.Linear(dim, 1, bias=False)
def __call__(
self,
x: mx.array,
):
ne = self.top_k
B, L, D = x.shape
x = x.reshape(-1, D)
# router_logits: (batch * sequence_length, n_experts)
gates = self.gate(x)
gates = mx.softmax(gates.astype(mx.float32), axis=-1)
gates = mx.softmax(gates, axis=-1, precise=True)
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne])
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = mx.take_along_axis(gates, inds, axis=-1).astype(x.dtype)
if self.training:
inds = np.array(inds)
y = mx.zeros((B * L, ne, D), x.dtype)
for e, expert in enumerate(self.experts):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0:
continue
y[idx1, idx2] = expert(x[idx1])
y = (y * scores[:, :, None]).sum(axis=1)
else:
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt)
y = mx.stack(y, axis=0)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
shared_expert_output = self.shared_expert(x)
shared_expert_output = (
mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output
)
y += shared_expert_output
return y.reshape(B, L, -1)
return y + shared_expert_output
class Qwen2MoeDecoderLayer(nn.Module):
@@ -243,12 +219,19 @@ class Model(nn.Module):
return self.lm_head(out)
def sanitize(self, weights):
if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n in ["up_proj", "down_proj", "gate_proj"]:
for k in ["weight", "scales", "biases"]:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
for e in range(self.args.num_experts)
]
if to_join:
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):

View File

@@ -0,0 +1,165 @@
import math
import mlx.core as mx
import mlx.nn as nn
class QuantizedSwitchLinear(nn.Module):
def __init__(
self,
input_dims: int,
output_dims: int,
num_experts: int,
bias: bool = True,
group_size: int = 64,
bits: int = 4,
):
super().__init__()
scale = math.sqrt(1 / input_dims)
self.weight, self.scales, self.biases = mx.quantize(
mx.random.uniform(
low=-scale,
high=scale,
shape=(num_experts, output_dims, input_dims),
),
group_size=group_size,
bits=bits,
)
if bias:
self.bias = mx.zeros((num_experts, output_dims))
self.group_size = group_size
self.bits = bits
# Freeze this model's parameters
self.freeze()
def unfreeze(self, *args, **kwargs):
"""Wrap unfreeze so that we unfreeze any layers we might contain but
our parameters will remain frozen."""
super().unfreeze(*args, **kwargs)
self.freeze(recurse=False)
@property
def input_dims(self):
return self.scales.shape[2] * self.group_size
@property
def output_dims(self):
return self.weight.shape[1]
@property
def num_experts(self):
return self.weight.shape[0]
def __call__(self, x, indices):
x = mx.block_sparse_qmm(
x,
self["weight"],
self["scales"],
self["biases"],
rhs_indices=indices,
transpose=True,
group_size=self.group_size,
bits=self.bits,
)
if "bias" in self:
x = x + mx.expand_dims(self["bias"][indices], -2)
return x
class SwitchLinear(nn.Module):
def __init__(
self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True
):
super().__init__()
scale = math.sqrt(1 / input_dims)
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(num_experts, output_dims, input_dims),
)
if bias:
self.bias = mx.zeros((num_experts, output_dims))
@property
def input_dims(self):
return self.weight.shape[2]
@property
def output_dims(self):
return self.weight.shape[1]
@property
def num_experts(self):
return self.weight.shape[0]
def __call__(self, x, indices):
x = mx.block_sparse_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
if "bias" in self:
x = x + mx.expand_dims(self["bias"][indices], -2)
return x
def to_quantized(self, group_size: int = 64, bits: int = 4):
num_experts, output_dims, input_dims = self.weight.shape
ql = QuantizedSwitchLinear(
input_dims, output_dims, num_experts, False, group_size, bits
)
ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits)
if "bias" in self:
ql.bias = self.bias
return ql
class SwitchGLU(nn.Module):
def __init__(
self,
input_dims: int,
hidden_dims: int,
num_experts: int,
activation=nn.silu,
bias: bool = False,
):
super().__init__()
self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
self.activation = activation
def __call__(self, x, indices) -> mx.array:
x = mx.expand_dims(x, (-2, -3))
x_up = self.up_proj(x, indices)
x_gate = self.gate_proj(x, indices)
x = self.down_proj(self.activation(x_gate) * x_up, indices)
return x.squeeze(-2)
class SwitchMLP(nn.Module):
def __init__(
self,
input_dims: int,
hidden_dims: int,
num_experts: int,
activation=nn.gelu_approx,
bias: bool = False,
):
super().__init__()
self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
self.activation = activation
def __call__(self, x, indices) -> mx.array:
x = mx.expand_dims(x, (-2, -3))
x = self.fc1(x, indices)
x = self.activation(x)
x = self.fc2(x, indices)
return x.squeeze(-2)