From 561dcf5643ecced3be9bf053499e0e2ea51e8686 Mon Sep 17 00:00:00 2001 From: Anchen Date: Thu, 18 Jul 2024 00:23:28 +1000 Subject: [PATCH] Add support for deepseek coder v2 lite (#882) * feat: add support for deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct * fix softmax + some cleanup * more nits * fix rope * fix original_max_position_embeddings in rope * fix original_max_position_embeddings in rope config * add group greedy --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/models/base.py | 14 +- llms/mlx_lm/models/deepseek_v2.py | 468 ++++++++++++++++++++++++++++++ 2 files changed, 478 insertions(+), 4 deletions(-) create mode 100644 llms/mlx_lm/models/deepseek_v2.py diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index b98a1909..8c3ecc78 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -15,7 +15,12 @@ class KVCache: def __init__(self, head_dim, n_kv_heads): self.n_kv_heads = n_kv_heads - self.head_dim = head_dim + if isinstance(head_dim, int): + self.k_head_dim = self.v_head_dim = head_dim + elif isinstance(head_dim, tuple) and len(head_dim) == 2: + self.k_head_dim, self.v_head_dim = head_dim + else: + raise ValueError("head_dim must be an int or a tuple of two ints") self.keys = None self.values = None self.offset = 0 @@ -25,9 +30,10 @@ class KVCache: prev = self.offset if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: n_steps = (self.step + keys.shape[2] - 1) // self.step - shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim) - new_k = mx.zeros(shape, keys.dtype) - new_v = mx.zeros(shape, values.dtype) + k_shape = (1, self.n_kv_heads, n_steps * self.step, self.k_head_dim) + v_shape = (1, self.n_kv_heads, n_steps * self.step, self.v_head_dim) + new_k = mx.zeros(k_shape, keys.dtype) + new_v = mx.zeros(v_shape, values.dtype) if self.keys is not None: if prev % self.step != 0: self.keys = self.keys[..., :prev, :] diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py new file mode 100644 index 00000000..308b94ba --- /dev/null +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -0,0 +1,468 @@ +import math +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, KVCache +from .switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str = "deepseek_v2" + vocab_size: int = 102400 + hidden_size: int = 4096 + intermediate_size: int = 11008 + moe_intermediate_size: int = 1407 + num_hidden_layers: int = 30 + num_attention_heads: int = 32 + num_key_value_heads: int = 32 + n_shared_experts: Optional[int] = None + n_routed_experts: Optional[int] = None + routed_scaling_factor: float = 1.0 + kv_lora_rank: int = 512 + q_lora_rank: int = 1536 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + qk_nope_head_dim: int = 128 + topk_method: str = "gready" + n_group: Optional[int] = None + topk_group: Optional[int] = None + num_experts_per_tok: Optional[int] = None + moe_layer_freq: int = 1 + first_k_dense_replace: int = 0 + max_position_embeddings: int = 2048 + rms_norm_eps: float = 1e-6 + rope_theta: float = 10000.0 + rope_scaling: Optional[Dict] = None + attention_bias: bool = False + + +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (mx.arange(dim, dtype=mx.float32) - min) / (max - min) + ramp_func = mx.clip(linear_func, 0, 1) + return ramp_func + + +class DeepseekV2YarnRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + + self.max_seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + self._inv_freq = None + self.set_cos_sin_cache(max_position_embeddings) + + def set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + dim = self.dim + freq_extra = 1.0 / (self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self._inv_freq = inv_freq + + t = mx.arange(seq_len, dtype=mx.float32) + freqs = mx.outer(t, inv_freq) + + mscale = yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale( + self.scaling_factor, self.mscale_all_dim + ) + + self._cos_cached = mx.cos(freqs) * mscale + self._sin_cached = mx.sin(freqs) * mscale + + def apply_rotary_pos_emb(self, x, cos, sin): + x1 = x[..., ::2] + x2 = x[..., 1::2] + rx1 = x1 * cos - x2 * sin + rx2 = x1 * sin + x2 * cos + return mx.concatenate([rx1, rx2], axis=-1) + + def __call__(self, x, offset=0): + seq_len = offset + x.shape[2] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self.set_cos_sin_cache(seq_len=seq_len) + + if self._cos_cached.dtype != x.dtype: + self._cos_cached = self._cos_cached.astype(x.dtype) + self._sin_cached = self._sin_cached.astype(x.dtype) + + return self.apply_rotary_pos_emb( + x, + self._cos_cached[offset:seq_len], + self._sin_cached[offset:seq_len], + ) + + +class DeepseekV2Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.scale = self.q_head_dim**-0.5 + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, self.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank) + self.q_b_proj = nn.Linear( + self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scale = self.scale * mscale * mscale + + rope_kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rope = DeepseekV2YarnRotaryEmbedding( + dim=self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **rope_kwargs, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + B, L, D = x.shape + + if self.q_lora_rank is None: + q = self.q_proj(x) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) + + q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) + q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1) + compressed_kv = self.kv_a_proj_with_mqa(x) + compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1) + k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) + + k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) + + k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1) + + if cache is not None: + q_pe = self.rope(q_pe, cache.offset) + k_pe = self.rope(k_pe, cache.offset) + keys, values = cache.update_and_fetch( + mx.concatenate([k_nope, k_pe], axis=-1), values + ) + else: + q_pe = self.rope(q_pe) + k_pe = self.rope(k_pe) + keys = mx.concatenate([k_nope, k_pe], axis=-1) + + queries = mx.concatenate([q_nope, q_pe], axis=-1) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class DeepseekV2MLP(nn.Module): + def __init__( + self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + def __call__(self, x): + down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + self.weight = mx.zeros((self.n_routed_experts, config.hidden_size)) + + def __call__(self, x): + gates = x @ self.weight.T + + scores = mx.softmax(gates, axis=-1, precise=True) + + if self.topk_method == "group_limited_greedy": + bsz, seq_len = x.shape[:2] + scores = scores.reshape(bsz, seq_len, self.n_group, -1) + group_scores = scores.max(axis=-1) + k = self.n_group - self.topk_group + group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k] + batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2)) + seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2)) + scores[batch_idx, seq_idx, group_idx] = 0.0 + scores = scores.reshape(bsz, seq_len, -1) + + k = self.top_k + inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]) + scores = mx.take_along_axis(scores, inds, axis=-1) + scores = scores * self.routed_scaling_factor + + return inds, scores + + +class DeepseekV2MoE(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + self.switch_mlp = SwitchGLU( + config.hidden_size, config.moe_intermediate_size, config.n_routed_experts + ) + + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV2MLP( + config=config, intermediate_size=intermediate_size + ) + + def __call__(self, x): + inds, scores = self.gate(x) + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(x) + + return y + + +class DeepseekV2DecoderLayer(nn.Module): + def __init__(self, config: ModelArgs, layer_idx: int): + super().__init__() + self.self_attn = DeepseekV2Attention(config) + self.mlp = ( + DeepseekV2MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV2MLP(config) + ) + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +class DeepseekV2Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = [ + DeepseekV2DecoderLayer(config, idx) + for idx in range(config.num_hidden_layers) + ] + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def __call__( + self, + x: mx.array, + cache: Optional[KVCache] = None, + ) -> mx.array: + h = self.embed_tokens(x) + mask = None + T = h.shape[1] + if T > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(T) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.args = config + self.model_type = config.model_type + self.model = DeepseekV2Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache: Optional[KVCache] = None, + ): + out = self.model(inputs, cache) + return self.lm_head(out) + + def sanitize(self, weights): + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: + for k in ["weight", "scales", "biases"]: + if f"{prefix}.mlp.experts.0.{m}.{k}" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") + for e in range(self.args.n_routed_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) + return weights + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return ( + self.args.qk_nope_head_dim + self.args.qk_rope_head_dim, + self.args.v_head_dim, + ) + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads