From 561dcf5643ecced3be9bf053499e0e2ea51e8686 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Thu, 18 Jul 2024 00:23:28 +1000
Subject: [PATCH] Add support for DeepSeek Coder V2 Lite (#882)
* feat: add support for deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct
* fix softmax + some cleanup
* more nits
* fix rope
* fix original_max_position_embeddings in rope
* fix original_max_position_embeddings in rope config
* add group greedy
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/models/base.py | 14 +-
llms/mlx_lm/models/deepseek_v2.py | 468 ++++++++++++++++++++++++++++++
2 files changed, 478 insertions(+), 4 deletions(-)
create mode 100644 llms/mlx_lm/models/deepseek_v2.py
diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index b98a1909..8c3ecc78 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -15,7 +15,12 @@ class KVCache:
     def __init__(self, head_dim, n_kv_heads):
         self.n_kv_heads = n_kv_heads
-        self.head_dim = head_dim
+        if isinstance(head_dim, int):
+            self.k_head_dim = self.v_head_dim = head_dim
+        elif isinstance(head_dim, tuple) and len(head_dim) == 2:
+            self.k_head_dim, self.v_head_dim = head_dim
+        else:
+            raise ValueError("head_dim must be an int or a tuple of two ints")
         self.keys = None
         self.values = None
         self.offset = 0
@@ -25,9 +30,10 @@ class KVCache:
         prev = self.offset
         if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
             n_steps = (self.step + keys.shape[2] - 1) // self.step
-            shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim)
-            new_k = mx.zeros(shape, keys.dtype)
-            new_v = mx.zeros(shape, values.dtype)
+            k_shape = (1, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
+            v_shape = (1, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
+            new_k = mx.zeros(k_shape, keys.dtype)
+            new_v = mx.zeros(v_shape, values.dtype)
             if self.keys is not None:
                 if prev % self.step != 0:
                     self.keys = self.keys[..., :prev, :]
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
new file mode 100644
index 00000000..308b94ba
--- /dev/null
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -0,0 +1,468 @@
+import math
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, KVCache
+from .switch_layers import SwitchGLU
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str = "deepseek_v2"
+    vocab_size: int = 102400
+    hidden_size: int = 4096
+    intermediate_size: int = 11008
+    moe_intermediate_size: int = 1407
+    num_hidden_layers: int = 30
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 32
+    n_shared_experts: Optional[int] = None
+    n_routed_experts: Optional[int] = None
+    routed_scaling_factor: float = 1.0
+    kv_lora_rank: int = 512
+    q_lora_rank: int = 1536
+    qk_rope_head_dim: int = 64
+    v_head_dim: int = 128
+    qk_nope_head_dim: int = 128
+    topk_method: str = "greedy"
+    n_group: Optional[int] = None
+    topk_group: Optional[int] = None
+    num_experts_per_tok: Optional[int] = None
+    moe_layer_freq: int = 1
+    first_k_dense_replace: int = 0
+    max_position_embeddings: int = 2048
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 10000.0
+    rope_scaling: Optional[Dict] = None
+    attention_bias: bool = False
+
+
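+# YaRN helpers: find which rotary dimensions need interpolation for the extended
+# context and build the ramp that blends interpolated and extrapolated frequencies.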
+def yarn_find_correction_dim(
+    num_rotations, dim, base=10000, max_position_embeddings=2048
+):
+    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
+        2 * math.log(base)
+    )
+
+
+def yarn_find_correction_range(
+    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
+):
+    low = math.floor(
+        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
+    )
+    high = math.ceil(
+        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
+    )
+    return max(low, 0), min(high, dim - 1)
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def yarn_linear_ramp_mask(min, max, dim):
+    if min == max:
+        max += 0.001  # Prevent singularity
+
+    linear_func = (mx.arange(dim, dtype=mx.float32) - min) / (max - min)
+    ramp_func = mx.clip(linear_func, 0, 1)
+    return ramp_func
+
+
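+# Rotary embedding with YaRN scaling. The cos/sin tables are built once for
+# max_position_embeddings and regrown lazily when a longer sequence is seen.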
+class DeepseekV2YarnRotaryEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        scaling_factor=1.0,
+        original_max_position_embeddings=4096,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1,
+        mscale_all_dim=0,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.scaling_factor = scaling_factor
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        self.mscale = mscale
+        self.mscale_all_dim = mscale_all_dim
+
+        self.max_seq_len_cached = None
+        self._cos_cached = None
+        self._sin_cached = None
+        self._inv_freq = None
+        self.set_cos_sin_cache(max_position_embeddings)
+
+    def set_cos_sin_cache(self, seq_len):
+        self.max_seq_len_cached = seq_len
+        dim = self.dim
+        freq_extra = 1.0 / (self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))
+        freq_inter = 1.0 / (
+            self.scaling_factor
+            * self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
+        )
+
+        low, high = yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            dim,
+            self.base,
+            self.original_max_position_embeddings,
+        )
+        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
+        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
+        self._inv_freq = inv_freq
+
+        t = mx.arange(seq_len, dtype=mx.float32)
+        freqs = mx.outer(t, inv_freq)
+
+        mscale = yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(
+            self.scaling_factor, self.mscale_all_dim
+        )
+
+        self._cos_cached = mx.cos(freqs) * mscale
+        self._sin_cached = mx.sin(freqs) * mscale
+
+    def apply_rotary_pos_emb(self, x, cos, sin):
+        x1 = x[..., ::2]
+        x2 = x[..., 1::2]
+        rx1 = x1 * cos - x2 * sin
+        rx2 = x1 * sin + x2 * cos
+        return mx.concatenate([rx1, rx2], axis=-1)
+
+    def __call__(self, x, offset=0):
+        seq_len = offset + x.shape[2]
+        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
+            self.set_cos_sin_cache(seq_len=seq_len)
+
+        if self._cos_cached.dtype != x.dtype:
+            self._cos_cached = self._cos_cached.astype(x.dtype)
+            self._sin_cached = self._sin_cached.astype(x.dtype)
+
+        return self.apply_rotary_pos_emb(
+            x,
+            self._cos_cached[offset:seq_len],
+            self._sin_cached[offset:seq_len],
+        )
+
+
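+# Multi-head latent attention (MLA): queries and keys/values pass through low-rank
+# projections (q_lora_rank / kv_lora_rank); only a qk_rope_head_dim slice of each
+# head carries rotary position information, the remaining dimensions are position-free.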
+class DeepseekV2Attention(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.q_lora_rank = config.q_lora_rank
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.kv_lora_rank = config.kv_lora_rank
+        self.v_head_dim = config.v_head_dim
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+
+        self.scale = self.q_head_dim**-0.5
+
+        if self.q_lora_rank is None:
+            self.q_proj = nn.Linear(
+                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
+            )
+        else:
+            self.q_a_proj = nn.Linear(
+                self.hidden_size, self.q_lora_rank, bias=config.attention_bias
+            )
+            self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
+            self.q_b_proj = nn.Linear(
+                self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
+            )
+
+        self.kv_a_proj_with_mqa = nn.Linear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=config.attention_bias,
+        )
+        self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
+        self.kv_b_proj = nn.Linear(
+            self.kv_lora_rank,
+            self.num_heads
+            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
+            bias=False,
+        )
+
+        self.o_proj = nn.Linear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=config.attention_bias,
+        )
+
+        if self.config.rope_scaling is not None:
+            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+            scaling_factor = self.config.rope_scaling["factor"]
+            if mscale_all_dim:
+                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+                self.scale = self.scale * mscale * mscale
+
+            rope_kwargs = {
+                key: self.config.rope_scaling[key]
+                for key in [
+                    "original_max_position_embeddings",
+                    "beta_fast",
+                    "beta_slow",
+                    "mscale",
+                    "mscale_all_dim",
+                ]
+                if key in self.config.rope_scaling
+            }
+            self.rope = DeepseekV2YarnRotaryEmbedding(
+                dim=self.qk_rope_head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                scaling_factor=scaling_factor,
+                base=self.rope_theta,
+                **rope_kwargs,
+            )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(x)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
+
+        q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
+        q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
+        compressed_kv = self.kv_a_proj_with_mqa(x)
+        compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
+        k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
+        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+        kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
+
+        k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
+
+        k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
+
+        if cache is not None:
+            q_pe = self.rope(q_pe, cache.offset)
+            k_pe = self.rope(k_pe, cache.offset)
+            keys, values = cache.update_and_fetch(
+                mx.concatenate([k_nope, k_pe], axis=-1), values
+            )
+        else:
+            q_pe = self.rope(q_pe)
+            k_pe = self.rope(k_pe)
+            keys = mx.concatenate([k_nope, k_pe], axis=-1)
+
+        queries = mx.concatenate([q_nope, q_pe], axis=-1)
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class DeepseekV2MLP(nn.Module):
+    def __init__(
+        self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.intermediate_size = (
+            config.intermediate_size if intermediate_size is None else intermediate_size
+        )
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+    def __call__(self, x):
+        down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
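+# Expert router: softmax scores over all routed experts, keep num_experts_per_tok.
+# With topk_method == "group_limited_greedy", experts are split into n_group groups
+# and the lowest-scoring groups are zeroed out so only topk_group groups stay selectable.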
+class MoEGate(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.config = config
+        self.top_k = config.num_experts_per_tok
+        self.n_routed_experts = config.n_routed_experts
+        self.routed_scaling_factor = config.routed_scaling_factor
+        self.topk_method = config.topk_method
+        self.n_group = config.n_group
+        self.topk_group = config.topk_group
+        self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
+
+    def __call__(self, x):
+        gates = x @ self.weight.T
+
+        scores = mx.softmax(gates, axis=-1, precise=True)
+
+        if self.topk_method == "group_limited_greedy":
+            bsz, seq_len = x.shape[:2]
+            scores = scores.reshape(bsz, seq_len, self.n_group, -1)
+            group_scores = scores.max(axis=-1)
+            k = self.n_group - self.topk_group
+            group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
+            batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
+            seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
+            scores[batch_idx, seq_idx, group_idx] = 0.0
+            scores = scores.reshape(bsz, seq_len, -1)
+
+        k = self.top_k
+        inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
+        scores = mx.take_along_axis(scores, inds, axis=-1)
+        scores = scores * self.routed_scaling_factor
+
+        return inds, scores
+
+
+class DeepseekV2MoE(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.config = config
+        self.num_experts_per_tok = config.num_experts_per_tok
+        self.switch_mlp = SwitchGLU(
+            config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
+        )
+
+        self.gate = MoEGate(config)
+        if config.n_shared_experts is not None:
+            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+            self.shared_experts = DeepseekV2MLP(
+                config=config, intermediate_size=intermediate_size
+            )
+
+    def __call__(self, x):
+        inds, scores = self.gate(x)
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2)
+        if self.config.n_shared_experts is not None:
+            y = y + self.shared_experts(x)
+
+        return y
+
+
+class DeepseekV2DecoderLayer(nn.Module):
+    def __init__(self, config: ModelArgs, layer_idx: int):
+        super().__init__()
+        self.self_attn = DeepseekV2Attention(config)
+        self.mlp = (
+            DeepseekV2MoE(config)
+            if (
+                config.n_routed_experts is not None
+                and layer_idx >= config.first_k_dense_replace
+                and layer_idx % config.moe_layer_freq == 0
+            )
+            else DeepseekV2MLP(config)
+        )
+        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out
+
+
+class DeepseekV2Model(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = [
+            DeepseekV2DecoderLayer(config, idx)
+            for idx in range(config.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def __call__(
+        self,
+        x: mx.array,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        h = self.embed_tokens(x)
+        mask = None
+        T = h.shape[1]
+        if T > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
+            mask = mask.astype(h.dtype)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
+
+        return self.norm(h)
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.args = config
+        self.model_type = config.model_type
+        self.model = DeepseekV2Model(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache: Optional[KVCache] = None,
+    ):
+        out = self.model(inputs, cache)
+        return self.lm_head(out)
+
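+    # Stack the per-expert gate/up/down projection weights from the original
+    # checkpoint into the fused tensors expected by SwitchGLU.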
+    def sanitize(self, weights):
+        for l in range(self.args.num_hidden_layers):
+            prefix = f"model.layers.{l}"
+            for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
+                for k in ["weight", "scales", "biases"]:
+                    if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
+                        to_join = [
+                            weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
+                            for e in range(self.args.n_routed_experts)
+                        ]
+                        weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+        return weights
+
+    @property
+    def layers(self):
+        return self.model.layers
+
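+    # MLA uses different key and value head dimensions, so head_dim is reported
+    # as the (k_head_dim, v_head_dim) tuple consumed by KVCache in base.py.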
+    @property
+    def head_dim(self):
+        return (
+            self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
+            self.args.v_head_dim,
+        )
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads