mlx-examples/llms/mlx_lm/models/deepseek_v2.py
Anchen 561dcf5643
Add support for deepseek coder v2 lite (#882)
* feat: add support for deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct

* fix softmax + some cleanup

* more nits

* fix rope

* fix original_max_position_embeddings in rope

* fix original_max_position_embeddings in rope config

* add group greedy

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-17 07:23:28 -07:00

469 lines
16 KiB
Python

import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyper-parameters of a DeepSeek-V2 model.

    Every field has a default that is used when the loaded configuration does
    not provide a value (the population mechanism lives in ``BaseModelArgs``,
    which is defined elsewhere).
    """

    model_type: str = "deepseek_v2"
    vocab_size: int = 102400
    hidden_size: int = 4096
    # Width of the dense (non-MoE) MLP layers.
    intermediate_size: int = 11008
    # Width of each routed expert; shared experts use this times n_shared_experts.
    moe_intermediate_size: int = 1407
    num_hidden_layers: int = 30
    num_attention_heads: int = 32
    # Reported as Model.n_kv_heads; the attention layer itself materialises
    # keys/values for num_attention_heads heads.
    num_key_value_heads: int = 32
    # None disables the always-on shared-expert MLP in MoE layers.
    n_shared_experts: Optional[int] = None
    # None makes every layer use a dense MLP instead of a mixture of experts.
    n_routed_experts: Optional[int] = None
    # Multiplier applied to the selected experts' routing weights.
    routed_scaling_factor: float = 1.0
    # Rank of the compressed key/value projection.
    kv_lora_rank: int = 512
    # Rank of the compressed query projection; None selects a full-rank q_proj.
    q_lora_rank: int = 1536
    # Per-head query/key dims that receive rotary position embeddings.
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    # Per-head query/key dims without positional encoding.
    qk_nope_head_dim: int = 128
    # "group_limited_greedy" restricts routing to the best expert groups; any
    # other value routes greedily. (Previously misspelled "gready".)
    topk_method: str = "greedy"
    # Number of expert groups (group_limited_greedy only).
    n_group: Optional[int] = None
    # Number of groups kept per token (group_limited_greedy only).
    topk_group: Optional[int] = None
    # Routed experts selected per token.
    num_experts_per_tok: Optional[int] = None
    # MoE is used on layers whose index is a multiple of this value...
    moe_layer_freq: int = 1
    # ...and only from this layer index onward (earlier layers stay dense).
    first_k_dense_replace: int = 0
    # Initial length of the rotary cos/sin cache.
    max_position_embeddings: int = 2048
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    # Optional YaRN settings read by the attention layer: "factor" plus any of
    # "original_max_position_embeddings", "beta_fast", "beta_slow", "mscale",
    # "mscale_all_dim".
    rope_scaling: Optional[Dict] = None
    attention_bias: bool = False
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    """Solve for the rotary frequency index at which ``num_rotations`` full
    turns occur over ``max_position_embeddings`` positions.

    The result is fractional; callers round it to an integer band edge.
    """
    turns_ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    numerator = dim * math.log(turns_ratio)
    denominator = 2 * math.log(base)
    return numerator / denominator
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Return the integer ``(low, high)`` band of rotary frequency indices.

    ``low`` is the floored index for ``low_rot`` rotations and ``high`` the
    ceiled index for ``high_rot`` rotations, clamped to ``[0, dim - 1]``.
    """
    shared = (dim, base, max_position_embeddings)
    low = max(math.floor(yarn_find_correction_dim(low_rot, *shared)), 0)
    high = min(math.ceil(yarn_find_correction_dim(high_rot, *shared)), dim - 1)
    return low, high
def yarn_get_mscale(scale=1, mscale=1):
    """YaRN magnitude correction ``0.1 * mscale * ln(scale) + 1``.

    Returns exactly ``1.0`` when ``scale <= 1``, i.e. when no context
    extension is applied.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
def yarn_linear_ramp_mask(min, max, dim):
    """Return a float32 ramp of length ``dim``.

    The ramp is 0 at or below index ``min``, rises linearly, and is 1 at or
    above index ``max``. The parameter names shadow builtins but are kept for
    keyword callers.
    """
    lo, hi = min, max
    if lo == hi:
        hi += 0.001  # avoid a zero-width ramp (division by zero)
    positions = mx.arange(dim, dtype=mx.float32)
    return mx.clip((positions - lo) / (hi - lo), 0, 1)
class DeepseekV2YarnRotaryEmbedding(nn.Module):
    """YaRN-scaled rotary position embedding.

    Precomputes float32 cos/sin tables for positions ``[0, max_seq_len_cached)``
    and grows them on demand. ``x.shape[2]`` is treated as the sequence axis;
    the attention layer in this file passes ``(batch, heads, seq, dim)``.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        self.max_seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None
        self._inv_freq = None
        self.set_cos_sin_cache(max_position_embeddings)

    def set_cos_sin_cache(self, seq_len):
        """(Re)build the float32 cos/sin tables for positions ``0..seq_len-1``."""
        self.max_seq_len_cached = seq_len
        dim = self.dim
        # Original (extrapolated) and position-interpolated frequencies.
        freq_extra = 1.0 / (self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))
        freq_inter = 1.0 / (
            self.scaling_factor
            * self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
        )
        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        # Indices below `low` keep freq_extra, indices above `high` use
        # freq_inter, and indices in between are linearly blended.
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self._inv_freq = inv_freq
        t = mx.arange(seq_len, dtype=mx.float32)
        freqs = mx.outer(t, inv_freq)
        mscale = yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(
            self.scaling_factor, self.mscale_all_dim
        )
        self._cos_cached = mx.cos(freqs) * mscale
        self._sin_cached = mx.sin(freqs) * mscale

    def apply_rotary_pos_emb(self, x, cos, sin):
        """Rotate interleaved pairs ``(x[..., 0::2], x[..., 1::2])``.

        The result is returned de-interleaved as ``[rotated evens, rotated odds]``.
        Queries and keys receive the same permutation, so their dot products are
        unaffected by the layout change.
        """
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * cos - x2 * sin
        rx2 = x1 * sin + x2 * cos
        return mx.concatenate([rx1, rx2], axis=-1)

    def __call__(self, x, offset=0):
        """Apply the embedding for positions ``offset .. offset + x.shape[2] - 1``."""
        seq_len = offset + x.shape[2]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self.set_cos_sin_cache(seq_len=seq_len)
        # Slice first, then cast. The cached tables stay float32, so a later
        # call with a higher-precision x never sees values that were already
        # rounded to a lower-precision dtype by an earlier call.
        cos = self._cos_cached[offset:seq_len].astype(x.dtype)
        sin = self._sin_cached[offset:seq_len].astype(x.dtype)
        return self.apply_rotary_pos_emb(x, cos, sin)
class DeepseekV2Attention(nn.Module):
    """DeepSeek-V2 attention with low-rank compressed queries and keys/values.

    Each query/key head is split into a part without positional encoding
    (``qk_nope_head_dim``) and a rotary part (``qk_rope_head_dim``). The rotary
    key part comes from a single head shared by all query heads.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        self.scale = self.q_head_dim**-0.5

        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            self.q_a_proj = nn.Linear(
                self.hidden_size, self.q_lora_rank, bias=config.attention_bias
            )
            # eps passed explicitly to match the reference implementation's
            # RMSNorm (1e-6) rather than relying on nn.RMSNorm's default.
            self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank, eps=1e-6)
            self.q_b_proj = nn.Linear(
                self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        # Projects to the compressed KV latent plus the shared rotary key head.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank, eps=1e-6)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )

        # Without rope_scaling, fall back to factor 1.0. Then the interpolated
        # and extrapolated frequencies coincide and both mscale terms are 1.0,
        # i.e. plain RoPE with base rope_theta. Previously self.rope was never
        # created in this case and __call__ raised AttributeError.
        rope_scaling = self.config.rope_scaling or {}
        scaling_factor = rope_scaling.get("factor", 1.0)
        mscale_all_dim = rope_scaling.get("mscale_all_dim", 0)
        if mscale_all_dim:
            mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
            self.scale = self.scale * mscale * mscale

        rope_kwargs = {
            key: rope_scaling[key]
            for key in [
                "original_max_position_embeddings",
                "beta_fast",
                "beta_slow",
                "mscale",
                "mscale_all_dim",
            ]
            if key in rope_scaling
        }
        self.rope = DeepseekV2YarnRotaryEmbedding(
            dim=self.qk_rope_head_dim,
            max_position_embeddings=self.max_position_embeddings,
            scaling_factor=scaling_factor,
            base=self.rope_theta,
            **rope_kwargs,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        """Attend over ``x`` (unpacked as batch, seq, hidden).

        ``mask`` is forwarded to scaled_dot_product_attention. ``cache``, if
        given, supplies ``offset`` and ``update_and_fetch``.
        """
        B, L, _ = x.shape

        if self.q_lora_rank is None:
            q = self.q_proj(x)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
        q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
        q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)

        compressed_kv = self.kv_a_proj_with_mqa(x)
        compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
        k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
        kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)

        offset = 0 if cache is None else cache.offset
        q_pe = self.rope(q_pe, offset)
        # Rotate the single shared key head once, then replicate it per head
        # (rotating num_heads identical copies gave the same result).
        k_pe = self.rope(k_pe, offset)
        k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)

        keys = mx.concatenate([k_nope, k_pe], axis=-1)
        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        queries = mx.concatenate([q_nope, q_pe], axis=-1)
        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
class DeepseekV2MLP(nn.Module):
    """Gated feed-forward block computing ``down(silu(gate(x)) * up(x))``.

    ``hidden_size`` and ``intermediate_size`` override the config values when
    given; this is used to size the shared experts of an MoE layer.
    """

    def __init__(
        self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
    ):
        super().__init__()
        self.config = config
        self.hidden_size = (
            hidden_size if hidden_size is not None else config.hidden_size
        )
        self.intermediate_size = (
            intermediate_size
            if intermediate_size is not None
            else config.intermediate_size
        )
        d_model, d_ff = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def __call__(self, x):
        activated = nn.silu(self.gate_proj(x))
        return self.down_proj(activated * self.up_proj(x))
class MoEGate(nn.Module):
    """Expert router.

    Returns, per token, the indices of ``num_experts_per_tok`` routed experts
    and their weights. The weights are the selected softmax probabilities
    multiplied by ``routed_scaling_factor``; they are not renormalised.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))

    def __call__(self, x):
        gates = x @ self.weight.T
        scores = mx.softmax(gates, axis=-1, precise=True)

        if self.topk_method == "group_limited_greedy":
            # Assumes x is (batch, seq, hidden), per the reshape below.
            bsz, seq_len = x.shape[:2]
            scores = scores.reshape(bsz, seq_len, self.n_group, -1)
            group_scores = scores.max(axis=-1)
            # Number of lowest-scoring groups to discard per token. When every
            # group is kept (k == 0) there is nothing to mask, and
            # argpartition(kth=-1) plus an empty scatter is skipped.
            k = self.n_group - self.topk_group
            if k > 0:
                group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
                batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
                seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
                # Zero every expert score in the discarded groups.
                scores[batch_idx, seq_idx, group_idx] = 0.0
            scores = scores.reshape(bsz, seq_len, -1)

        k = self.top_k
        inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
        scores = mx.take_along_axis(scores, inds, axis=-1)
        scores = scores * self.routed_scaling_factor
        return inds, scores
class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts feed-forward layer.

    Computes a routing-weighted sum of the selected routed experts, plus an
    always-on shared-expert MLP when ``n_shared_experts`` is set.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
        self.switch_mlp = SwitchGLU(
            config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
        )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = DeepseekV2MLP(
                config=config,
                intermediate_size=config.moe_intermediate_size
                * config.n_shared_experts,
            )

    def __call__(self, x):
        expert_inds, expert_weights = self.gate(x)
        expert_out = self.switch_mlp(x, expert_inds)
        # Weight each selected expert's output and sum over the expert axis.
        y = (expert_out * expert_weights[..., None]).sum(axis=-2)
        if self.config.n_shared_experts is None:
            return y
        return y + self.shared_experts(x)
class DeepseekV2DecoderLayer(nn.Module):
    """Pre-norm transformer block: attention, then an MoE or dense MLP.

    Each sub-layer is wrapped in a residual connection.
    """

    def __init__(self, config: ModelArgs, layer_idx: int):
        super().__init__()
        self.self_attn = DeepseekV2Attention(config)
        # MoE only when routed experts exist, past the leading dense layers,
        # and on every `moe_layer_freq`-th layer.
        is_moe_layer = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        if is_moe_layer:
            self.mlp = DeepseekV2MoE(config)
        else:
            self.mlp = DeepseekV2MLP(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        h = x + self.self_attn(self.input_layernorm(x), mask, cache)
        return h + self.mlp(self.post_attention_layernorm(h))
class DeepseekV2Model(nn.Module):
    """Token embedding, a stack of decoder layers, and a final RMSNorm."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            DeepseekV2DecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @staticmethod
    def _causal_mask(T: int, offset: int, dtype) -> mx.array:
        """Build an additive causal mask of shape ``(T, offset + T)``.

        Query i sits at absolute position ``offset + i`` and may not attend to
        later keys. With ``offset == 0`` this is the original square mask.
        """
        if offset == 0:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
        else:
            q_pos = mx.arange(offset, offset + T)
            k_pos = mx.arange(offset + T)
            # Large negative bias for future keys; 0 elsewhere.
            mask = (q_pos[:, None] < k_pos[None]).astype(mx.float32) * -1e9
        return mask.astype(dtype)

    def __call__(
        self,
        x: mx.array,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        """Return final hidden states for token ids ``x``.

        ``cache`` is one cache per layer, or None.
        """
        h = self.embed_tokens(x)
        if cache is None:
            cache = [None] * len(self.layers)

        mask = None
        T = h.shape[1]
        if T > 1:
            # With a non-empty cache, the keys presumably span offset + T
            # positions (update_and_fetch returns cached + new keys), so a
            # square T x T mask would not match them.
            first = cache[0] if cache else None
            offset = 0 if first is None else first.offset
            mask = self._causal_mask(T, offset, h.dtype)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)
        return self.norm(h)
class Model(nn.Module):
    """DeepSeek-V2 causal language model: backbone plus an untied LM head."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.args = config
        self.model_type = config.model_type
        self.model = DeepseekV2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[KVCache] = None,
    ):
        return self.lm_head(self.model(inputs, cache))

    def sanitize(self, weights):
        """Stack per-expert tensors into the single-array layout of ``switch_mlp``.

        For each layer, projection and tensor kind found under
        ``...mlp.experts.{e}.{proj}.{kind}``, the per-expert entries are
        popped and replaced by one stacked ``...mlp.switch_mlp.{proj}.{kind}``
        entry. ``weights`` is mutated in place and returned.
        """
        num_experts = self.args.n_routed_experts
        for layer in range(self.args.num_hidden_layers):
            mlp_prefix = f"model.layers.{layer}.mlp"
            for proj in ("gate_proj", "down_proj", "up_proj"):
                for kind in ("weight", "scales", "biases"):
                    if f"{mlp_prefix}.experts.0.{proj}.{kind}" not in weights:
                        continue
                    per_expert = [
                        weights.pop(f"{mlp_prefix}.experts.{e}.{proj}.{kind}")
                        for e in range(num_experts)
                    ]
                    weights[f"{mlp_prefix}.switch_mlp.{proj}.{kind}"] = mx.stack(
                        per_expert
                    )
        return weights

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        # (query/key head dim, value head dim); the two differ in this model.
        return (
            self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
            self.args.v_head_dim,
        )

    @property
    def n_kv_heads(self):
        # NOTE(review): the attention layer caches num_attention_heads heads of
        # keys/values; this relies on num_key_value_heads matching it — confirm.
        return self.args.num_key_value_heads