From d661440dbb8e1970fadad79c5061e786fe1c54ca Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Tue, 2 Apr 2024 20:33:29 +0200 Subject: [PATCH] Add support for qwen2moe (#640) * add sparsemoe block and update decoder logic * update file name to match HF * update name * Code formatting * update gates calculation * add support for Qwen2MoE. * fix pytest * code formatting and fix missing comma in utils * Remove decoder sparse step. Co-authored-by: bozheng-hit * remove gate layer anti-quantisation * remove unused argument --------- Co-authored-by: bozheng-hit --- llms/mlx_lm/models/qwen2_moe.py | 257 ++++++++++++++++++++++++++++++++ llms/mlx_lm/tuner/utils.py | 4 + llms/tests/test_models.py | 21 +++ 3 files changed, 282 insertions(+) create mode 100644 llms/mlx_lm/models/qwen2_moe.py diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py new file mode 100644 index 00000000..536d2e1b --- /dev/null +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -0,0 +1,257 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + num_experts_per_tok: int + num_experts: int + moe_intermediate_size: int + shared_expert_intermediate_size: int + rms_norm_eps: float + vocab_size: int + num_key_value_heads: int = None + rope_theta: float = 1000000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = False + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + if self.rope_scaling["type"] != "linear": + raise ValueError("rope_scaling 'type' currently only supports 'linear'") + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + self.rope = nn.RoPE( + head_dim, + traditional=args.rope_traditional, + base=args.rope_theta, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values) + + +class Qwen2MoeMLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class Qwen2MoeSparseMoeBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + intermediate_size = args.moe_intermediate_size + shared_expert_intermediate_size = args.shared_expert_intermediate_size + + self.num_experts = num_experts = args.num_experts + self.top_k = args.num_experts_per_tok + + # gating + self.gate = nn.Linear(dim, num_experts, bias=False) + self.experts = [ + Qwen2MoeMLP(dim, intermediate_size) for _ in range(self.num_experts) + ] + self.shared_expert = Qwen2MoeMLP(dim, shared_expert_intermediate_size) + self.shared_expert_gate = nn.Linear(dim, 1, bias=False) + + def __call__( + self, + x: mx.array, + ): + ne = self.top_k + B, L, D = x.shape + x = x.reshape(-1, D) + + # router_logits: (batch * sequence_length, n_experts) + gates = self.gate(x) + gates = mx.softmax(gates.astype(mx.float32), axis=-1) + + inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]) + + scores = mx.take_along_axis(gates, inds, axis=-1).astype(x.dtype) + + if self.training: + inds = np.array(inds) + y = mx.zeros((B, ne, D), x.dtype) + for e, expert in enumerate(self.experts): + idx1, idx2 = map(mx.array, np.where(inds == e)) + if idx1.size == 0: + continue + y[idx1, idx2] = expert(x[idx1]) + + y = (y * scores[:, :, None]).sum(axis=1) + else: + y = [] + for xt, st, it in zip(x, scores, inds.tolist()): + yt = mx.stack([self.experts[e](xt) for e in it], axis=-1) + yt = (yt * st).sum(axis=-1) + y.append(yt) + + y = mx.stack(y, axis=0) + + shared_expert_output = self.shared_expert(x) + shared_expert_output = ( + mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output + ) + + y += shared_expert_output + + return y.reshape(B, L, -1) + + +class Qwen2MoeDecoderLayer(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = Qwen2MoeSparseMoeBlock(args) + + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, cache + + +class Qwen2MoeModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + Qwen2MoeDecoderLayer(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.norm(h), cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = Qwen2MoeModel(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out, cache = self.model(inputs, cache) + return self.lm_head(out), cache + + def sanitize(self, weights): + if self.args.tie_word_embeddings and "lm_head.weight" not in weights: + weights["lm_head.weight"] = weights["model.embed_tokens.weight"] + # Remove unused precomputed rotary freqs + return { + k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k + } + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 91990d84..40d42ee4 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -70,6 +70,7 @@ def linear_to_lora_layers( "mixtral", "stablelm", "qwen2", + "qwen2_moe", "gemma", "starcoder2", "cohere", @@ -77,6 +78,9 @@ def linear_to_lora_layers( keys = set(["self_attn.q_proj", "self_attn.v_proj"]) if model.model_type == "mixtral": keys.add("block_sparse_moe.gate") + if model.model_type == "qwen2_moe": + keys.add("mlp.gate") + keys.add("mlp.shared_expert_gate") elif model.model_type == "olmo": keys = set(["att_proj"]) elif model.model_type == "phi-msft": diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 865d419d..effeab53 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -114,6 +114,27 @@ class TestModels(unittest.TestCase): args.n_layers, ) + def test_qwen2_moe(self): + from mlx_lm.models import qwen2_moe + + args = qwen2_moe.ModelArgs( + model_type="qwen2_moe", + hidden_size=1024, + num_hidden_layers=4, + intermediate_size=2048, + num_attention_heads=4, + rms_norm_eps=1e-5, + vocab_size=10_000, + num_experts_per_tok=4, + num_experts=16, + moe_intermediate_size=1024, + shared_expert_intermediate_size=2048, + ) + model = qwen2_moe.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + def test_qwen2(self): from mlx_lm.models import qwen2