From 9f671228cdf9c33de41f59a9b488e1f9cdfd33a3 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 21 May 2024 15:58:08 -0700 Subject: [PATCH] Block sparse MM MoEs (#782) - Adds SwitchLinear - Adds QuantizedSwitchLinear --- llms/mlx_lm/fuse.py | 4 +- llms/mlx_lm/models/mixtral.py | 77 +++++-------- llms/mlx_lm/models/phixtral.py | 66 +++++------ llms/mlx_lm/models/qwen2_moe.py | 69 +++++------- llms/mlx_lm/models/switch_layers.py | 165 ++++++++++++++++++++++++++++ llms/mlx_lm/tuner/lora.py | 97 ++++++++++++++++ llms/mlx_lm/tuner/utils.py | 21 +++- llms/mlx_lm/utils.py | 9 +- 8 files changed, 365 insertions(+), 143 deletions(-) create mode 100644 llms/mlx_lm/models/switch_layers.py diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py index 1c7250e7..fa06eb54 100644 --- a/llms/mlx_lm/fuse.py +++ b/llms/mlx_lm/fuse.py @@ -7,7 +7,7 @@ from mlx.utils import tree_flatten, tree_unflatten from .gguf import convert_to_gguf from .tuner.dora import DoRALinear -from .tuner.lora import LoRALinear +from .tuner.lora import LoRALinear, LoRASwitchLinear from .tuner.utils import apply_lora_layers, dequantize from .utils import ( fetch_from_hub, @@ -82,7 +82,7 @@ def main() -> None: fused_linears = [ (n, m.to_linear()) for n, m in model.named_modules() - if isinstance(m, (LoRALinear, DoRALinear)) + if isinstance(m, (LoRASwitchLinear, LoRALinear, DoRALinear)) ] model.update_modules(tree_unflatten(fused_linears)) diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index 7bf67638..ee401e8e 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -1,11 +1,12 @@ +import math from dataclasses import dataclass from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -import numpy as np from .base import BaseModelArgs +from .switch_layers import SwitchGLU @dataclass @@ -91,24 +92,6 @@ class MixtralAttention(nn.Module): return self.o_proj(output) -class MixtralBLockSparseTop2MLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.ffn_dim = args.intermediate_size - self.hidden_dim = args.hidden_size - - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - - self.act_fn = nn.silu - - def __call__(self, x: mx.array) -> mx.array: - current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - - class MixtralSparseMoeBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -120,43 +103,20 @@ class MixtralSparseMoeBlock(nn.Module): # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.experts = [ - MixtralBLockSparseTop2MLP(args=args) for _ in range(self.num_experts) - ] + self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts) def __call__(self, x: mx.array) -> mx.array: - ne = self.num_experts_per_tok - orig_shape = x.shape - x = x.reshape(-1, x.shape[-1]) - gates = self.gate(x) - inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]) + k = self.num_experts_per_tok + inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) + scores = mx.take_along_axis(gates, inds, axis=-1) + scores = mx.softmax(scores, axis=-1, precise=True) - scores = mx.softmax( - mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), - axis=-1, - ).astype(gates.dtype) + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2) - if self.training: - inds = np.array(inds) - y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype) - for e, expert in enumerate(self.experts): - idx1, idx2 = map(mx.array, np.where(inds == e)) - if idx1.size == 0: - continue - y[idx1, idx2] = expert(x[idx1]) - - y = (y * scores[:, :, None]).sum(axis=1) - else: - y = [] - for xt, st, it in zip(x, scores, inds.tolist()): - yt = mx.stack([self.experts[e](xt) for e in it], axis=-1) - yt = (yt * st).sum(axis=-1) - y.append(yt[None, :]) - y = mx.concatenate(y) - - return y.reshape(orig_shape) + return y class MixtralDecoderLayer(nn.Module): @@ -235,6 +195,23 @@ class Model(nn.Module): out = self.model(inputs, cache) return self.lm_head(out) + def sanitize(self, weights): + if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights: + return weights + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: + for k in ["weight", "scales", "biases"]: + to_join = [ + weights.pop(f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}") + for e in range(self.args.num_local_experts) + ] + if to_join: + weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = ( + mx.stack(to_join) + ) + return weights + @property def layers(self): return self.model.layers diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index 7413e3cd..ded56c68 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -5,7 +5,8 @@ from typing import Tuple import mlx.core as mx import mlx.nn as nn -import numpy as np + +from .switch_layers import SwitchMLP @dataclass @@ -75,17 +76,6 @@ class RoPEAttention(nn.Module): return self.out_proj(output) -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.fc1 = nn.Linear(dim, hidden_dim) - self.fc2 = nn.Linear(hidden_dim, dim) - self.act = nn.GELU(approx="precise") - - def __call__(self, x) -> mx.array: - return self.fc2(self.act(self.fc1(x))) - - class MOE(nn.Module): def __init__(self, args: ModelArgs, dim: int, hidden_dim: int): super().__init__() @@ -93,40 +83,23 @@ class MOE(nn.Module): self.hidden_dim = hidden_dim self.num_experts = args.num_local_experts self.num_experts_per_tok = args.num_experts_per_tok - self.mlp = [MLP(self.dim, self.hidden_dim) for _ in range(self.num_experts)] + self.switch_mlp = SwitchMLP( + self.dim, self.hidden_dim, self.num_experts, bias=True + ) self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False) def __call__(self, x: mx.array) -> mx.array: - ne = self.num_experts_per_tok - orig_shape = x.shape - x = x.reshape(-1, x.shape[-1]) - gates = self.gate(x) - inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1))[:, :ne] - scores = mx.softmax( - mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), - axis=-1, - ).astype(gates.dtype) - if self.training: - ys = [] - y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype) - for e, expert in enumerate(self.mlp): - idx1, idx2 = map(mx.array, np.where(inds == e)) - if idx1.size == 0: - continue - y[idx1, idx2] = expert(x[idx1]) + k = self.num_experts_per_tok + inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1))[..., :k] + scores = mx.take_along_axis(gates, inds, axis=-1) + scores = mx.softmax(scores, axis=-1, precise=True) - y = (y * scores[..., None]).sum(axis=1) - else: - y = [] - for xt, st, it in zip(x, scores, inds.tolist()): - yt = mx.stack([self.mlp[e](xt) for e in it], axis=-1) - yt = (yt * st).sum(axis=-1) - y.append(yt[None, :]) - y = mx.concatenate(y) + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2) - return y.reshape(orig_shape) + return y class ParallelBlock(nn.Module): @@ -202,6 +175,21 @@ class Model(nn.Module): y = self.transformer(x, mask, cache) return self.lm_head(y) + def sanitize(self, weights): + if "transformer.h.0.moe.mlp.0.fc1.weight" not in weights: + return weights + for l in range(self.args.num_layers): + prefix = f"transformer.h.{l}" + for n in ["fc1", "fc2"]: + for k in ["weight", "scales", "biases", "bias"]: + to_join = [ + weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}") + for e in range(self.args.num_local_experts) + ] + if to_join: + weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join) + return weights + @property def layers(self): return self.transformer.h diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index ea8ab802..1bd065aa 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -1,11 +1,12 @@ +import math from dataclasses import dataclass from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -import numpy as np from .base import BaseModelArgs +from .switch_layers import SwitchGLU @dataclass @@ -92,7 +93,7 @@ class Attention(nn.Module): return self.o_proj(output) -class Qwen2MoeMLP(nn.Module): +class MLP(nn.Module): def __init__(self, dim, hidden_dim): super().__init__() self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) @@ -113,57 +114,32 @@ class Qwen2MoeSparseMoeBlock(nn.Module): self.num_experts = num_experts = args.num_experts self.top_k = args.num_experts_per_tok - # gating self.gate = nn.Linear(dim, num_experts, bias=False) - self.experts = [ - Qwen2MoeMLP(dim, intermediate_size) for _ in range(self.num_experts) - ] - self.shared_expert = Qwen2MoeMLP(dim, shared_expert_intermediate_size) + self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) + + self.shared_expert = MLP(dim, shared_expert_intermediate_size) self.shared_expert_gate = nn.Linear(dim, 1, bias=False) def __call__( self, x: mx.array, ): - ne = self.top_k - B, L, D = x.shape - x = x.reshape(-1, D) - - # router_logits: (batch * sequence_length, n_experts) gates = self.gate(x) - gates = mx.softmax(gates.astype(mx.float32), axis=-1) + gates = mx.softmax(gates, axis=-1, precise=True) - inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]) + k = self.top_k + inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) + scores = mx.take_along_axis(gates, inds, axis=-1) - scores = mx.take_along_axis(gates, inds, axis=-1).astype(x.dtype) - - if self.training: - inds = np.array(inds) - y = mx.zeros((B * L, ne, D), x.dtype) - for e, expert in enumerate(self.experts): - idx1, idx2 = map(mx.array, np.where(inds == e)) - if idx1.size == 0: - continue - y[idx1, idx2] = expert(x[idx1]) - - y = (y * scores[:, :, None]).sum(axis=1) - else: - y = [] - for xt, st, it in zip(x, scores, inds.tolist()): - yt = mx.stack([self.experts[e](xt) for e in it], axis=-1) - yt = (yt * st).sum(axis=-1) - y.append(yt) - - y = mx.stack(y, axis=0) + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2) shared_expert_output = self.shared_expert(x) shared_expert_output = ( mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output ) - y += shared_expert_output - - return y.reshape(B, L, -1) + return y + shared_expert_output class Qwen2MoeDecoderLayer(nn.Module): @@ -243,12 +219,19 @@ class Model(nn.Module): return self.lm_head(out) def sanitize(self, weights): - if self.args.tie_word_embeddings and "lm_head.weight" not in weights: - weights["lm_head.weight"] = weights["model.embed_tokens.weight"] - # Remove unused precomputed rotary freqs - return { - k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k - } + if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: + return weights + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n in ["up_proj", "down_proj", "gate_proj"]: + for k in ["weight", "scales", "biases"]: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") + for e in range(self.args.num_experts) + ] + if to_join: + weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) + return weights @property def layers(self): diff --git a/llms/mlx_lm/models/switch_layers.py b/llms/mlx_lm/models/switch_layers.py new file mode 100644 index 00000000..cad99ec0 --- /dev/null +++ b/llms/mlx_lm/models/switch_layers.py @@ -0,0 +1,165 @@ +import math + +import mlx.core as mx +import mlx.nn as nn + + +class QuantizedSwitchLinear(nn.Module): + def __init__( + self, + input_dims: int, + output_dims: int, + num_experts: int, + bias: bool = True, + group_size: int = 64, + bits: int = 4, + ): + super().__init__() + + scale = math.sqrt(1 / input_dims) + self.weight, self.scales, self.biases = mx.quantize( + mx.random.uniform( + low=-scale, + high=scale, + shape=(num_experts, output_dims, input_dims), + ), + group_size=group_size, + bits=bits, + ) + + if bias: + self.bias = mx.zeros((num_experts, output_dims)) + + self.group_size = group_size + self.bits = bits + + # Freeze this model's parameters + self.freeze() + + def unfreeze(self, *args, **kwargs): + """Wrap unfreeze so that we unfreeze any layers we might contain but + our parameters will remain frozen.""" + super().unfreeze(*args, **kwargs) + self.freeze(recurse=False) + + @property + def input_dims(self): + return self.scales.shape[2] * self.group_size + + @property + def output_dims(self): + return self.weight.shape[1] + + @property + def num_experts(self): + return self.weight.shape[0] + + def __call__(self, x, indices): + x = mx.block_sparse_qmm( + x, + self["weight"], + self["scales"], + self["biases"], + rhs_indices=indices, + transpose=True, + group_size=self.group_size, + bits=self.bits, + ) + if "bias" in self: + x = x + mx.expand_dims(self["bias"][indices], -2) + return x + + +class SwitchLinear(nn.Module): + def __init__( + self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True + ): + super().__init__() + scale = math.sqrt(1 / input_dims) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(num_experts, output_dims, input_dims), + ) + + if bias: + self.bias = mx.zeros((num_experts, output_dims)) + + @property + def input_dims(self): + return self.weight.shape[2] + + @property + def output_dims(self): + return self.weight.shape[1] + + @property + def num_experts(self): + return self.weight.shape[0] + + def __call__(self, x, indices): + x = mx.block_sparse_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices) + if "bias" in self: + x = x + mx.expand_dims(self["bias"][indices], -2) + return x + + def to_quantized(self, group_size: int = 64, bits: int = 4): + num_experts, output_dims, input_dims = self.weight.shape + ql = QuantizedSwitchLinear( + input_dims, output_dims, num_experts, False, group_size, bits + ) + ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits) + if "bias" in self: + ql.bias = self.bias + return ql + + +class SwitchGLU(nn.Module): + def __init__( + self, + input_dims: int, + hidden_dims: int, + num_experts: int, + activation=nn.silu, + bias: bool = False, + ): + super().__init__() + + self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) + self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) + self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) + self.activation = activation + + def __call__(self, x, indices) -> mx.array: + x = mx.expand_dims(x, (-2, -3)) + + x_up = self.up_proj(x, indices) + x_gate = self.gate_proj(x, indices) + x = self.down_proj(self.activation(x_gate) * x_up, indices) + + return x.squeeze(-2) + + +class SwitchMLP(nn.Module): + def __init__( + self, + input_dims: int, + hidden_dims: int, + num_experts: int, + activation=nn.gelu_approx, + bias: bool = False, + ): + super().__init__() + + self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) + self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) + self.activation = activation + + def __call__(self, x, indices) -> mx.array: + x = mx.expand_dims(x, (-2, -3)) + + x = self.fc1(x, indices) + x = self.activation(x) + x = self.fc2(x, indices) + + return x.squeeze(-2) diff --git a/llms/mlx_lm/tuner/lora.py b/llms/mlx_lm/tuner/lora.py index 76894509..22b0b4d0 100644 --- a/llms/mlx_lm/tuner/lora.py +++ b/llms/mlx_lm/tuner/lora.py @@ -5,6 +5,8 @@ import math import mlx.core as mx import mlx.nn as nn +from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear + class LoRALinear(nn.Module): @staticmethod @@ -100,3 +102,98 @@ class LoRALinear(nn.Module): y = self.linear(x) z = (self.dropout(x) @ self.lora_a) @ self.lora_b return y + (self.scale * z).astype(x.dtype) + + +class LoRASwitchLinear(nn.Module): + @staticmethod + def from_linear( + linear: nn.Module, + r: int = 8, + alpha: float = 16, + dropout: float = 0.0, + scale: float = 10.0, + ): + lora_lin = LoRASwitchLinear( + input_dims=linear.input_dims, + output_dims=linear.output_dims, + num_experts=linear.num_experts, + r=r, + alpha=alpha, + dropout=dropout, + scale=scale, + ) + lora_lin.linear = linear + return lora_lin + + def to_linear(self, de_quantize: bool = False): + linear = self.linear + bias = "bias" in linear + weight = linear.weight + is_quantized = isinstance(linear, QuantizedSwitchLinear) + + # Use the same type as the linear weight if not quantized + dtype = weight.dtype + + if is_quantized: + dtype = mx.float16 + weight = mx.dequantize( + weight, + linear.scales, + linear.biases, + linear.group_size, + linear.bits, + ) + num_experts, output_dims, input_dims = weight.shape + fused_linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias) + + lora_b = (self.scale * self.lora_b).astype(dtype) + lora_a = self.lora_a.reshape(num_experts, -1, input_dims).astype(dtype) + fused_linear.weight = weight + lora_b @ lora_a + if bias: + fused_linear.bias = linear.bias + + if is_quantized and not de_quantize: + fused_linear = fused_linear.to_quantized(linear.group_size, linear.bits) + + return fused_linear + + def __init__( + self, + input_dims: int, + output_dims: int, + num_experts: int, + r: int = 8, + alpha: float = 16, + dropout: float = 0.0, + scale: float = 10.0, + bias: bool = False, + ): + super().__init__() + + # Regular linear layer weights + self.linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias) + + self.dropout = nn.Dropout(p=dropout) + + # Scale for low-rank update + self.scale = scale * (alpha / r) + + # Low rank lora weights + scale = 1 / math.sqrt(input_dims) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(r * num_experts, input_dims), + ) + self.lora_b = mx.zeros(shape=(num_experts, output_dims, r)) + self.num_experts = num_experts + + def __call__(self, x, indices): + shape = x.shape[:-3] + (self.num_experts, -1) + + y = self.linear(x, indices) + z = (self.dropout(x) @ self.lora_a.T).reshape(shape) + z = mx.take_along_axis(z, indices[..., None], axis=-2) + z = z[..., None, :] @ self.lora_b[indices].swapaxes(-2, -1) + + return y + (self.scale * z).astype(x.dtype) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index e976e4af..03f782a1 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -9,8 +9,9 @@ import mlx.nn as nn import mlx.optimizers as opt from mlx.utils import tree_unflatten +from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear from .dora import DoRALinear -from .lora import LoRALinear +from .lora import LoRALinear, LoRASwitchLinear def build_schedule(schedule_config: Dict): @@ -58,11 +59,21 @@ def linear_to_lora_layers( f"Requested {num_lora_layers} LoRA layers " f"but the model only has {num_layers} layers." ) - cls = DoRALinear if use_dora else LoRALinear - def to_lora(lin): - return cls.from_linear( - lin, + def to_lora(layer): + if isinstance(layer, (nn.Linear, nn.QuantizedLinear)): + LoRALayer = DoRALinear if use_dora else LoRALinear + elif isinstance(layer, (SwitchLinear, QuantizedSwitchLinear)): + if use_dora: + raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.") + LoRALayer = LoRASwitchLinear + else: + raise ValueError( + f"Can't convert layer of type {type(layer).__name__} to LoRA" + ) + + return LoRALayer.from_linear( + layer, r=config["rank"], alpha=config["alpha"], scale=config["scale"], diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 11653572..d665325e 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -366,10 +366,11 @@ def load_model( if (quantization := config.get("quantization", None)) is not None: # Handle legacy models which may not have everything quantized - class_predicate = ( - lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) - and f"{p}.scales" in weights - ) + def class_predicate(p, m): + if not hasattr(m, "to_quantized"): + return False + return f"{p}.scales" in weights + nn.quantize( model, **quantization,