Block sparse MM MoEs (#782)

- Adds SwitchLinear
- Adds QuantizedSwitchLinear
Angelos Katharopoulos 2024-05-21 15:58:08 -07:00 committed by GitHub
parent 199df9e110
commit 9f671228cd
8 changed files with 365 additions and 143 deletions
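
Across the mixtral, phixtral, and qwen2_moe model files the per-expert Python loops are replaced by one routing pattern: pick the top-k experts per token with argpartition, gather (and, for two of the models, renormalize) their gate scores, and hand the unmodified tokens plus the index tensor to a fused switch layer. A minimal sketch of just that routing step, using the same mx calls as the diff but with made-up sizes:

    import mlx.core as mx

    # Made-up sizes: batch 2, sequence 3, feature dim 16, 8 experts, top-2 routing.
    B, L, D, E, k = 2, 3, 16, 8, 2
    x = mx.random.normal((B, L, D))
    gates = mx.random.normal((B, L, E))  # stand-in for the gate projection's logits

    # Indices of the k largest gate logits per token; the selection carries no gradient.
    inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])

    # Routing weights for the selected experts, renormalized over the k of them.
    scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1, precise=True)

    print(inds.shape, scores.shape)  # (2, 3, 2) (2, 3, 2)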

fuse.py

@@ -7,7 +7,7 @@ from mlx.utils import tree_flatten, tree_unflatten
 from .gguf import convert_to_gguf
 from .tuner.dora import DoRALinear
-from .tuner.lora import LoRALinear
+from .tuner.lora import LoRALinear, LoRASwitchLinear
 from .tuner.utils import apply_lora_layers, dequantize
 from .utils import (
     fetch_from_hub,
@@ -82,7 +82,7 @@ def main() -> None:
     fused_linears = [
         (n, m.to_linear())
         for n, m in model.named_modules()
-        if isinstance(m, (LoRALinear, DoRALinear))
+        if isinstance(m, (LoRASwitchLinear, LoRALinear, DoRALinear))
     ]
     model.update_modules(tree_unflatten(fused_linears))

models/mixtral.py

@@ -1,11 +1,12 @@
+import math
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np

 from .base import BaseModelArgs
+from .switch_layers import SwitchGLU


 @dataclass
@@ -91,24 +92,6 @@ class MixtralAttention(nn.Module):
         return self.o_proj(output)


-class MixtralBLockSparseTop2MLP(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.ffn_dim = args.intermediate_size
-        self.hidden_dim = args.hidden_size
-
-        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
-        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
-        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
-
-        self.act_fn = nn.silu
-
-    def __call__(self, x: mx.array) -> mx.array:
-        current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
-        current_hidden_states = self.w2(current_hidden_states)
-        return current_hidden_states
-
-
 class MixtralSparseMoeBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -120,43 +103,20 @@ class MixtralSparseMoeBlock(nn.Module):
         # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
-        self.experts = [
-            MixtralBLockSparseTop2MLP(args=args) for _ in range(self.num_experts)
-        ]
+        self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)

     def __call__(self, x: mx.array) -> mx.array:
-        ne = self.num_experts_per_tok
-        orig_shape = x.shape
-        x = x.reshape(-1, x.shape[-1])
-
         gates = self.gate(x)
-        inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
-
-        scores = mx.softmax(
-            mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
-            axis=-1,
-        ).astype(gates.dtype)
-
-        if self.training:
-            inds = np.array(inds)
-            y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
-            for e, expert in enumerate(self.experts):
-                idx1, idx2 = map(mx.array, np.where(inds == e))
-                if idx1.size == 0:
-                    continue
-                y[idx1, idx2] = expert(x[idx1])
-            y = (y * scores[:, :, None]).sum(axis=1)
-        else:
-            y = []
-            for xt, st, it in zip(x, scores, inds.tolist()):
-                yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
-                yt = (yt * st).sum(axis=-1)
-                y.append(yt[None, :])
-            y = mx.concatenate(y)
-
-        return y.reshape(orig_shape)
+
+        k = self.num_experts_per_tok
+        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
+        scores = mx.take_along_axis(gates, inds, axis=-1)
+        scores = mx.softmax(scores, axis=-1, precise=True)
+
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2)
+
+        return y


 class MixtralDecoderLayer(nn.Module):
@@ -235,6 +195,23 @@ class Model(nn.Module):
         out = self.model(inputs, cache)
         return self.lm_head(out)

+    def sanitize(self, weights):
+        if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
+            return weights
+        for l in range(self.args.num_hidden_layers):
+            prefix = f"model.layers.{l}"
+            for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
+                for k in ["weight", "scales", "biases"]:
+                    to_join = [
+                        weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}")
+                        for e in range(self.args.num_local_experts)
+                    ]
+                    if to_join:
+                        weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
+                            mx.stack(to_join)
+                        )
+        return weights
+
     @property
     def layers(self):
         return self.model.layers
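
The sanitize hook added above converts checkpoints saved in the old per-expert layout: for each of w1/w2/w3 it stacks the separate (output_dims, input_dims) matrices into one (num_experts, output_dims, input_dims) tensor under the new switch_mlp.gate_proj / down_proj / up_proj names. A toy illustration of that stacking, with hypothetical sizes (not Mixtral's real dimensions) and a single layer and key:

    import mlx.core as mx

    num_experts, d_ff, d_model = 4, 32, 16

    # Old checkpoint layout: one (d_ff, d_model) matrix per expert.
    per_expert = {
        f"model.layers.0.block_sparse_moe.experts.{e}.w1.weight": mx.random.normal((d_ff, d_model))
        for e in range(num_experts)
    }

    # New layout: the matrices stacked along a leading expert axis, as sanitize() does.
    stacked = mx.stack(
        [
            per_expert[f"model.layers.0.block_sparse_moe.experts.{e}.w1.weight"]
            for e in range(num_experts)
        ]
    )
    print(stacked.shape)  # (4, 32, 16) == (num_experts, output_dims, input_dims)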

models/phixtral.py

@@ -5,7 +5,8 @@ from typing import Tuple
 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np
+
+from .switch_layers import SwitchMLP


 @dataclass
@@ -75,17 +76,6 @@ class RoPEAttention(nn.Module):
         return self.out_proj(output)


-class MLP(nn.Module):
-    def __init__(self, dim, hidden_dim):
-        super().__init__()
-        self.fc1 = nn.Linear(dim, hidden_dim)
-        self.fc2 = nn.Linear(hidden_dim, dim)
-        self.act = nn.GELU(approx="precise")
-
-    def __call__(self, x) -> mx.array:
-        return self.fc2(self.act(self.fc1(x)))
-
-
 class MOE(nn.Module):
     def __init__(self, args: ModelArgs, dim: int, hidden_dim: int):
         super().__init__()
@@ -93,40 +83,23 @@ class MOE(nn.Module):
         self.hidden_dim = hidden_dim
         self.num_experts = args.num_local_experts
         self.num_experts_per_tok = args.num_experts_per_tok
-        self.mlp = [MLP(self.dim, self.hidden_dim) for _ in range(self.num_experts)]
+        self.switch_mlp = SwitchMLP(
+            self.dim, self.hidden_dim, self.num_experts, bias=True
+        )
         self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)

     def __call__(self, x: mx.array) -> mx.array:
-        ne = self.num_experts_per_tok
-        orig_shape = x.shape
-        x = x.reshape(-1, x.shape[-1])
-
         gates = self.gate(x)
-        inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1))[:, :ne]
-        scores = mx.softmax(
-            mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
-            axis=-1,
-        ).astype(gates.dtype)
-
-        if self.training:
-            ys = []
-            y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
-            for e, expert in enumerate(self.mlp):
-                idx1, idx2 = map(mx.array, np.where(inds == e))
-                if idx1.size == 0:
-                    continue
-                y[idx1, idx2] = expert(x[idx1])
-            y = (y * scores[..., None]).sum(axis=1)
-        else:
-            y = []
-            for xt, st, it in zip(x, scores, inds.tolist()):
-                yt = mx.stack([self.mlp[e](xt) for e in it], axis=-1)
-                yt = (yt * st).sum(axis=-1)
-                y.append(yt[None, :])
-            y = mx.concatenate(y)
-
-        return y.reshape(orig_shape)
+
+        k = self.num_experts_per_tok
+        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1))[..., :k]
+        scores = mx.take_along_axis(gates, inds, axis=-1)
+        scores = mx.softmax(scores, axis=-1, precise=True)
+
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2)
+
+        return y


 class ParallelBlock(nn.Module):
@@ -202,6 +175,21 @@ class Model(nn.Module):
         y = self.transformer(x, mask, cache)
         return self.lm_head(y)

+    def sanitize(self, weights):
+        if "transformer.h.0.moe.mlp.0.fc1.weight" not in weights:
+            return weights
+        for l in range(self.args.num_layers):
+            prefix = f"transformer.h.{l}"
+            for n in ["fc1", "fc2"]:
+                for k in ["weight", "scales", "biases", "bias"]:
+                    to_join = [
+                        weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
+                        for e in range(self.args.num_local_experts)
+                    ]
+                    if to_join:
+                        weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
+        return weights
+
     @property
     def layers(self):
         return self.transformer.h
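
phixtral's MOE gets the same treatment, except its experts are plain fc1/fc2 MLPs with biases, so the block maps onto SwitchMLP(..., bias=True) and, unlike the old code, never flattens the (B, L, dim) input. A rough sketch of how the shapes flow through the new call, written against the switch_layers module added further down in this diff; the mlx_lm import path and the toy sizes are assumptions:

    import mlx.core as mx
    from mlx_lm.models.switch_layers import SwitchMLP  # package path assumed

    dim, hidden_dim, num_experts, k = 16, 64, 4, 2
    moe = SwitchMLP(dim, hidden_dim, num_experts, bias=True)

    x = mx.random.normal((1, 5, dim))                    # (B, L, dim) tokens, no flattening needed
    inds = mx.random.randint(0, num_experts, (1, 5, k))  # top-k expert ids per token
    y = moe(x, inds)                                     # (1, 5, k, dim): one output per chosen expert

    scores = mx.random.uniform(shape=(1, 5, k, 1))       # the model uses softmaxed gate scores here
    out = (y * scores).sum(axis=-2)                      # (1, 5, dim), as in MOE.__call__
    print(y.shape, out.shape)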

models/qwen2_moe.py

@@ -1,11 +1,12 @@
+import math
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np

 from .base import BaseModelArgs
+from .switch_layers import SwitchGLU


 @dataclass
@@ -92,7 +93,7 @@ class Attention(nn.Module):
         return self.o_proj(output)


-class Qwen2MoeMLP(nn.Module):
+class MLP(nn.Module):
     def __init__(self, dim, hidden_dim):
         super().__init__()
         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
@@ -113,57 +114,32 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.num_experts = num_experts = args.num_experts
         self.top_k = args.num_experts_per_tok

-        # gating
         self.gate = nn.Linear(dim, num_experts, bias=False)
-        self.experts = [
-            Qwen2MoeMLP(dim, intermediate_size) for _ in range(self.num_experts)
-        ]
-        self.shared_expert = Qwen2MoeMLP(dim, shared_expert_intermediate_size)
+        self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
+
+        self.shared_expert = MLP(dim, shared_expert_intermediate_size)
         self.shared_expert_gate = nn.Linear(dim, 1, bias=False)

     def __call__(
         self,
         x: mx.array,
     ):
-        ne = self.top_k
-        B, L, D = x.shape
-        x = x.reshape(-1, D)
-
-        # router_logits: (batch * sequence_length, n_experts)
         gates = self.gate(x)
-        gates = mx.softmax(gates.astype(mx.float32), axis=-1)
-
-        inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne])
-
-        scores = mx.take_along_axis(gates, inds, axis=-1).astype(x.dtype)
-
-        if self.training:
-            inds = np.array(inds)
-            y = mx.zeros((B * L, ne, D), x.dtype)
-            for e, expert in enumerate(self.experts):
-                idx1, idx2 = map(mx.array, np.where(inds == e))
-                if idx1.size == 0:
-                    continue
-                y[idx1, idx2] = expert(x[idx1])
-            y = (y * scores[:, :, None]).sum(axis=1)
-        else:
-            y = []
-            for xt, st, it in zip(x, scores, inds.tolist()):
-                yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
-                yt = (yt * st).sum(axis=-1)
-                y.append(yt)
-            y = mx.stack(y, axis=0)
+        gates = mx.softmax(gates, axis=-1, precise=True)
+
+        k = self.top_k
+        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
+        scores = mx.take_along_axis(gates, inds, axis=-1)
+
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2)

         shared_expert_output = self.shared_expert(x)
         shared_expert_output = (
             mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output
         )

-        y += shared_expert_output
-
-        return y.reshape(B, L, -1)
+        return y + shared_expert_output


 class Qwen2MoeDecoderLayer(nn.Module):
@@ -243,12 +219,19 @@ class Model(nn.Module):
         return self.lm_head(out)

     def sanitize(self, weights):
-        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
-            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
-        # Remove unused precomputed rotary freqs
-        return {
-            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
-        }
+        if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
+            return weights
+        for l in range(self.args.num_hidden_layers):
+            prefix = f"model.layers.{l}"
+            for n in ["up_proj", "down_proj", "gate_proj"]:
+                for k in ["weight", "scales", "biases"]:
+                    to_join = [
+                        weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
+                        for e in range(self.args.num_experts)
+                    ]
+                    if to_join:
+                        weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
+        return weights

     @property
     def layers(self):
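
One detail visible in the qwen2_moe rewrite: the gate logits are softmaxed over all experts before the top-k gather and the gathered scores are not renormalized, whereas mixtral and phixtral softmax only the k selected logits. A small comparison of the two conventions, using toy logits of my own (not taken from any model):

    import mlx.core as mx

    logits = mx.array([[2.0, 1.0, 0.5, -1.0]])  # toy gate logits for 4 experts
    k = 2
    inds = mx.argpartition(-logits, kth=k - 1, axis=-1)[..., :k]

    # qwen2_moe style: softmax over all experts, then gather (weights need not sum to 1).
    probs_all = mx.softmax(logits, axis=-1, precise=True)
    scores_qwen = mx.take_along_axis(probs_all, inds, axis=-1)

    # mixtral / phixtral style: gather the selected logits, then softmax over just those k.
    scores_mix = mx.softmax(mx.take_along_axis(logits, inds, axis=-1), axis=-1, precise=True)

    print(scores_qwen.sum(axis=-1), scores_mix.sum(axis=-1))  # less than 1 vs. exactly 1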

models/switch_layers.py (new file)

@@ -0,0 +1,165 @@
+import math
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class QuantizedSwitchLinear(nn.Module):
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        num_experts: int,
+        bias: bool = True,
+        group_size: int = 64,
+        bits: int = 4,
+    ):
+        super().__init__()
+
+        scale = math.sqrt(1 / input_dims)
+        self.weight, self.scales, self.biases = mx.quantize(
+            mx.random.uniform(
+                low=-scale,
+                high=scale,
+                shape=(num_experts, output_dims, input_dims),
+            ),
+            group_size=group_size,
+            bits=bits,
+        )
+
+        if bias:
+            self.bias = mx.zeros((num_experts, output_dims))
+
+        self.group_size = group_size
+        self.bits = bits
+
+        # Freeze this model's parameters
+        self.freeze()
+
+    def unfreeze(self, *args, **kwargs):
+        """Wrap unfreeze so that we unfreeze any layers we might contain but
+        our parameters will remain frozen."""
+        super().unfreeze(*args, **kwargs)
+        self.freeze(recurse=False)
+
+    @property
+    def input_dims(self):
+        return self.scales.shape[2] * self.group_size
+
+    @property
+    def output_dims(self):
+        return self.weight.shape[1]
+
+    @property
+    def num_experts(self):
+        return self.weight.shape[0]
+
+    def __call__(self, x, indices):
+        x = mx.block_sparse_qmm(
+            x,
+            self["weight"],
+            self["scales"],
+            self["biases"],
+            rhs_indices=indices,
+            transpose=True,
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+        if "bias" in self:
+            x = x + mx.expand_dims(self["bias"][indices], -2)
+        return x
+
+
+class SwitchLinear(nn.Module):
+    def __init__(
+        self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True
+    ):
+        super().__init__()
+
+        scale = math.sqrt(1 / input_dims)
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(num_experts, output_dims, input_dims),
+        )
+
+        if bias:
+            self.bias = mx.zeros((num_experts, output_dims))
+
+    @property
+    def input_dims(self):
+        return self.weight.shape[2]
+
+    @property
+    def output_dims(self):
+        return self.weight.shape[1]
+
+    @property
+    def num_experts(self):
+        return self.weight.shape[0]
+
+    def __call__(self, x, indices):
+        x = mx.block_sparse_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
+        if "bias" in self:
+            x = x + mx.expand_dims(self["bias"][indices], -2)
+        return x
+
+    def to_quantized(self, group_size: int = 64, bits: int = 4):
+        num_experts, output_dims, input_dims = self.weight.shape
+        ql = QuantizedSwitchLinear(
+            input_dims, output_dims, num_experts, False, group_size, bits
+        )
+        ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits)
+        if "bias" in self:
+            ql.bias = self.bias
+        return ql
+
+
+class SwitchGLU(nn.Module):
+    def __init__(
+        self,
+        input_dims: int,
+        hidden_dims: int,
+        num_experts: int,
+        activation=nn.silu,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
+        self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
+        self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
+        self.activation = activation
+
+    def __call__(self, x, indices) -> mx.array:
+        x = mx.expand_dims(x, (-2, -3))
+
+        x_up = self.up_proj(x, indices)
+        x_gate = self.gate_proj(x, indices)
+        x = self.down_proj(self.activation(x_gate) * x_up, indices)
+
+        return x.squeeze(-2)
+
+
+class SwitchMLP(nn.Module):
+    def __init__(
+        self,
+        input_dims: int,
+        hidden_dims: int,
+        num_experts: int,
+        activation=nn.gelu_approx,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
+        self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
+        self.activation = activation
+
+    def __call__(self, x, indices) -> mx.array:
+        x = mx.expand_dims(x, (-2, -3))
+
+        x = self.fc1(x, indices)
+        x = self.activation(x)
+        x = self.fc2(x, indices)
+
+        return x.squeeze(-2)
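
switch_layers.py is the heart of the change: SwitchLinear keeps every expert's weights in a single (num_experts, output_dims, input_dims) array and dispatches tokens with mx.block_sparse_mm (mx.block_sparse_qmm once quantized), so each token only touches its selected experts in one batched call. A hedged usage sketch, assuming the mlx_lm package layout, toy sizes of my choosing, and an MLX version that provides these block-sparse ops and nn.quantize:

    import mlx.core as mx
    import mlx.nn as nn
    from mlx_lm.models.switch_layers import SwitchGLU  # package path assumed

    glu = SwitchGLU(input_dims=64, hidden_dims=128, num_experts=8)  # SiLU-gated, bias-free

    x = mx.random.normal((2, 10, 64))            # (batch, sequence, features)
    inds = mx.random.randint(0, 8, (2, 10, 2))   # two expert ids per token
    y = glu(x, inds)
    print(y.shape)                               # (2, 10, 2, 64)

    # nn.quantize swaps each SwitchLinear for a QuantizedSwitchLinear via to_quantized();
    # the predicate is a simplified form of the one load_model installs later in this diff.
    nn.quantize(glu, group_size=64, bits=4,
                class_predicate=lambda _, m: hasattr(m, "to_quantized"))
    print(type(glu.up_proj).__name__)            # QuantizedSwitchLinear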

tuner/lora.py

@@ -5,6 +5,8 @@ import math
 import mlx.core as mx
 import mlx.nn as nn

+from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
+

 class LoRALinear(nn.Module):
     @staticmethod
@@ -100,3 +102,98 @@ class LoRALinear(nn.Module):
         y = self.linear(x)
         z = (self.dropout(x) @ self.lora_a) @ self.lora_b
         return y + (self.scale * z).astype(x.dtype)
+
+
+class LoRASwitchLinear(nn.Module):
+    @staticmethod
+    def from_linear(
+        linear: nn.Module,
+        r: int = 8,
+        alpha: float = 16,
+        dropout: float = 0.0,
+        scale: float = 10.0,
+    ):
+        lora_lin = LoRASwitchLinear(
+            input_dims=linear.input_dims,
+            output_dims=linear.output_dims,
+            num_experts=linear.num_experts,
+            r=r,
+            alpha=alpha,
+            dropout=dropout,
+            scale=scale,
+        )
+        lora_lin.linear = linear
+        return lora_lin
+
+    def to_linear(self, de_quantize: bool = False):
+        linear = self.linear
+        bias = "bias" in linear
+        weight = linear.weight
+        is_quantized = isinstance(linear, QuantizedSwitchLinear)
+
+        # Use the same type as the linear weight if not quantized
+        dtype = weight.dtype
+
+        if is_quantized:
+            dtype = mx.float16
+            weight = mx.dequantize(
+                weight,
+                linear.scales,
+                linear.biases,
+                linear.group_size,
+                linear.bits,
+            )
+
+        num_experts, output_dims, input_dims = weight.shape
+        fused_linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias)
+
+        lora_b = (self.scale * self.lora_b).astype(dtype)
+        lora_a = self.lora_a.reshape(num_experts, -1, input_dims).astype(dtype)
+        fused_linear.weight = weight + lora_b @ lora_a
+        if bias:
+            fused_linear.bias = linear.bias
+
+        if is_quantized and not de_quantize:
+            fused_linear = fused_linear.to_quantized(linear.group_size, linear.bits)
+
+        return fused_linear
+
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        num_experts: int,
+        r: int = 8,
+        alpha: float = 16,
+        dropout: float = 0.0,
+        scale: float = 10.0,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        # Regular linear layer weights
+        self.linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias)
+
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Scale for low-rank update
+        self.scale = scale * (alpha / r)
+
+        # Low rank lora weights
+        scale = 1 / math.sqrt(input_dims)
+        self.lora_a = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(r * num_experts, input_dims),
+        )
+        self.lora_b = mx.zeros(shape=(num_experts, output_dims, r))
+        self.num_experts = num_experts
+
+    def __call__(self, x, indices):
+        shape = x.shape[:-3] + (self.num_experts, -1)
+
+        y = self.linear(x, indices)
+        z = (self.dropout(x) @ self.lora_a.T).reshape(shape)
+        z = mx.take_along_axis(z, indices[..., None], axis=-2)
+        z = z[..., None, :] @ self.lora_b[indices].swapaxes(-2, -1)
+
+        return y + (self.scale * z).astype(x.dtype)
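
LoRASwitchLinear gives every expert its own rank-r adapter: lora_a is stored flat as (r * num_experts, input_dims), lora_b as (num_experts, output_dims, r), and the routing indices pick out the per-expert slices at call time. A small sketch of wrapping a switch layer and fusing the adapter back in, with toy sizes and the mlx_lm package layout assumed:

    import mlx.core as mx
    from mlx_lm.models.switch_layers import SwitchLinear  # package path assumed
    from mlx_lm.tuner.lora import LoRASwitchLinear

    base = SwitchLinear(input_dims=16, output_dims=32, num_experts=4, bias=False)
    lora = LoRASwitchLinear.from_linear(base, r=8, alpha=16, scale=10.0)

    # A is stored flat as (r * num_experts, in); B as (num_experts, out, r).
    print(lora.lora_a.shape, lora.lora_b.shape)   # (32, 16) (4, 32, 8)

    # After training, fold scale * B @ A back into a plain SwitchLinear, as fuse.py does.
    fused = lora.to_linear()
    print(fused.weight.shape)                     # (4, 32, 16)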

tuner/utils.py

@@ -9,8 +9,9 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_unflatten

+from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
 from .dora import DoRALinear
-from .lora import LoRALinear
+from .lora import LoRALinear, LoRASwitchLinear


 def build_schedule(schedule_config: Dict):
@@ -58,11 +59,21 @@ def linear_to_lora_layers(
             f"Requested {num_lora_layers} LoRA layers "
             f"but the model only has {num_layers} layers."
         )
-    cls = DoRALinear if use_dora else LoRALinear

-    def to_lora(lin):
-        return cls.from_linear(
-            lin,
+    def to_lora(layer):
+        if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
+            LoRALayer = DoRALinear if use_dora else LoRALinear
+        elif isinstance(layer, (SwitchLinear, QuantizedSwitchLinear)):
+            if use_dora:
+                raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
+            LoRALayer = LoRASwitchLinear
+        else:
+            raise ValueError(
+                f"Can't convert layer of type {type(layer).__name__} to LoRA"
+            )
+
+        return LoRALayer.from_linear(
+            layer,
             r=config["rank"],
             alpha=config["alpha"],
             scale=config["scale"],

utils.py

@@ -366,10 +366,11 @@ def load_model(
     if (quantization := config.get("quantization", None)) is not None:
         # Handle legacy models which may not have everything quantized
-        class_predicate = (
-            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
-            and f"{p}.scales" in weights
-        )
+        def class_predicate(p, m):
+            if not hasattr(m, "to_quantized"):
+                return False
+            return f"{p}.scales" in weights
+
         nn.quantize(
             model,
             **quantization,