# Copyright © 2024 Apple Inc.

import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str = "deepseek_v3"
    vocab_size: int = 102400
    hidden_size: int = 4096
    intermediate_size: int = 11008
    moe_intermediate_size: int = 1407
    num_hidden_layers: int = 30
    num_attention_heads: int = 32
    num_key_value_heads: int = 32
    n_shared_experts: Optional[int] = None
    n_routed_experts: Optional[int] = None
    routed_scaling_factor: float = 1.0
    kv_lora_rank: int = 512
    q_lora_rank: int = 1536
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    qk_nope_head_dim: int = 128
    topk_method: str = "noaux_tc"
    scoring_func: str = "sigmoid"
    norm_topk_prob: bool = True
    n_group: Optional[int] = None
    topk_group: Optional[int] = None
    num_experts_per_tok: Optional[int] = None
    moe_layer_freq: int = 1
    first_k_dense_replace: int = 0
    max_position_embeddings: int = 2048
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    rope_scaling: Dict = None
    attention_bias: bool = False


def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    low = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    high = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    return max(low, 0), min(high, dim - 1)


def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def yarn_linear_ramp_mask(min_val, max_val, dim):
    if min_val == max_val:
        max_val += 0.001  # Prevent singularity

    linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
    return mx.clip(linear_func, 0, 1)


class DeepseekV3YarnRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        super().__init__()
        self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
            scaling_factor, mscale_all_dim
        )
        freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
        freq_inter = scaling_factor * base ** (
            mx.arange(0, dim, 2, dtype=mx.float32) / dim
        )
        low, high = yarn_find_correction_range(
            beta_fast,
            beta_slow,
            dim,
            base,
            original_max_position_embeddings,
        )
        freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
        self._freqs = (freq_inter * freq_extra) / (
            freq_inter * freq_mask + freq_extra * (1 - freq_mask)
        )

    def __call__(self, x, offset=0):
        if self.mscale != 1.0:
            x = self.mscale * x
        return mx.fast.rope(
            x,
            x.shape[-1],
            traditional=True,
            base=None,
            scale=1.0,
            offset=offset,
            freqs=self._freqs,
        )


# A clipped silu to prevent fp16 from overflowing
@partial(mx.compile, shapeless=True)
def clipped_silu(x):
    return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)


class DeepseekV3Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.scale = self.q_head_dim**-0.5

        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            self.q_a_proj = nn.Linear(
                self.hidden_size, self.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
            self.q_b_proj = nn.Linear(
                self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )

        mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
        scaling_factor = self.config.rope_scaling["factor"]
        if mscale_all_dim:
            mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
            self.scale = self.scale * mscale * mscale

        rope_kwargs = {
            key: self.config.rope_scaling[key]
            for key in [
                "original_max_position_embeddings",
                "beta_fast",
                "beta_slow",
                "mscale",
                "mscale_all_dim",
            ]
            if key in self.config.rope_scaling
        }
        self.rope = DeepseekV3YarnRotaryEmbedding(
            dim=self.qk_rope_head_dim,
            max_position_embeddings=self.max_position_embeddings,
            scaling_factor=scaling_factor,
            base=self.rope_theta,
            **rope_kwargs,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        if self.q_lora_rank is None:
            q = self.q_proj(x)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))

        q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
        q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)

        compressed_kv = self.kv_a_proj_with_mqa(x)
        compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
        k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
        kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)

        k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)

        if cache is not None:
            q_pe = self.rope(q_pe, cache.offset)
            k_pe = self.rope(k_pe, cache.offset)
            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
            keys, values = cache.update_and_fetch(
                mx.concatenate([k_nope, k_pe], axis=-1), values
            )
        else:
            q_pe = self.rope(q_pe)
            k_pe = self.rope(k_pe)
            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
            keys = mx.concatenate([k_nope, k_pe], axis=-1)

        queries = mx.concatenate([q_nope, q_pe], axis=-1)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class DeepseekV3MLP(nn.Module):
    def __init__(
        self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
        self.intermediate_size = (
            config.intermediate_size if intermediate_size is None else intermediate_size
        )

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def __call__(self, x):
        down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


@mx.compile
def group_expert_select(
    gates,
    e_score_correction_bias,
    top_k,
    n_group,
    topk_group,
    routed_scaling_factor,
    norm_topk_prob,
):
    k = top_k
    scores = mx.sigmoid(gates.astype(mx.float32))
    scores = scores + e_score_correction_bias
    scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
    group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
    k = n_group - topk_group
    group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
    scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2)
    scores = mx.flatten(scores, -2, -1)

    k = top_k
    inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
    scores = mx.take_along_axis(scores, inds, axis=-1)
    if top_k > 1 and norm_topk_prob:
        denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
        scores = scores / denominator
    scores = scores * routed_scaling_factor

    return inds, scores


class MoEGate(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
        self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
        assert config.topk_method == "noaux_tc", "Unsupported topk method."

    def __call__(self, x):
        return group_expert_select(
            x @ self.weight.T,
            self.e_score_correction_bias,
            self.top_k,
            self.n_group,
            self.topk_group,
            self.routed_scaling_factor,
            self.norm_topk_prob,
        )


class DeepseekV3MoE(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
        self.switch_mlp = SwitchGLU(
            config.hidden_size,
            config.moe_intermediate_size,
            config.n_routed_experts,
            activation=clipped_silu,
        )

        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekV3MLP(
                config=config, intermediate_size=intermediate_size
            )

    def __call__(self, x):
        inds, scores = self.gate(x)
        y = self.switch_mlp(x, inds)
        y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(x)

        return y


class DeepseekV3DecoderLayer(nn.Module):
    def __init__(self, config: ModelArgs, layer_idx: int):
        super().__init__()
        self.self_attn = DeepseekV3Attention(config)
        self.mlp = (
            DeepseekV3MoE(config)
            if (
                config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0
            )
            else DeepseekV3MLP(config)
        )
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        return h + r


class DeepseekV3Model(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            DeepseekV3DecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.start_idx = 0
        self.end_idx = len(self.layers)
        self.num_layers = self.end_idx

        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.pipeline_rank = 0
        self.pipeline_size = 1

    def pipeline(self, group):
        # Split layers in reverse so rank=0 gets the last layers and
        # rank=pipeline_size-1 gets the first
        self.pipeline_rank = group.rank()
        self.pipeline_size = group.size()
        layers_per_rank = len(self.layers) // self.pipeline_size
        extra = len(self.layers) - layers_per_rank * self.pipeline_size
        if self.pipeline_rank < extra:
            layers_per_rank += 1
        self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
        self.end_idx = self.start_idx + layers_per_rank
        self.layers = self.layers[: self.end_idx]
        self.layers[: self.start_idx] = [None] * self.start_idx
        self.num_layers = len(self.layers) - self.start_idx

    def __call__(
        self,
        x: mx.array,
        cache: Optional[Any] = None,
        mask: Optional[mx.array] = None,
    ) -> mx.array:
        h = self.embed_tokens(x)

        pipeline_rank = self.pipeline_rank
        pipeline_size = self.pipeline_size
        # Hack to avoid time-outs during prompt-processing
        dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * self.num_layers

        # Receive from the previous process in the pipeline
        if pipeline_rank < pipeline_size - 1:
            h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)

        for i in range(self.num_layers):
            h = self.layers[self.start_idx + i](h, mask, cache[i])

        # Send to the next process in the pipeline
        if pipeline_rank != 0:
            h = mx.distributed.send(
                h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
            )

        # Broadcast h while keeping it in the graph
        h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.args = config
        self.model_type = config.model_type
        self.model = DeepseekV3Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
        mask: Optional[mx.array] = None,
    ):
        out = self.model(inputs, cache, mask)
        return self.lm_head(out)

    def sanitize(self, weights):
        for l in range(self.args.num_hidden_layers):
            prefix = f"model.layers.{l}"
            for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
                for k in ["weight", "scales", "biases"]:
                    if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
                        to_join = [
                            weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
                            for e in range(self.args.n_routed_experts)
                        ]
                        weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)

        # Remove multi-token prediction layer
        return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}

    @property
    def layers(self):
        return self.model.layers[self.model.start_idx : self.model.end_idx]
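

# --- Illustrative smoke test (not part of the upstream module) ---
# A minimal sketch, assuming this file sits inside the mlx-lm package so the
# relative imports above resolve. All sizes below are made-up toy values; the
# config uses dense MLPs only (n_routed_experts=None stays at its default), so
# the MoE path is not exercised. It runs a single forward pass on random
# weights purely to check that the shapes line up.
if __name__ == "__main__":
    args = ModelArgs(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        kv_lora_rank=16,
        q_lora_rank=32,
        qk_rope_head_dim=8,
        qk_nope_head_dim=16,
        v_head_dim=16,
        # DeepseekV3Attention reads rope_scaling["factor"] unconditionally
        rope_scaling={"factor": 1.0},
    )
    model = Model(args)
    tokens = mx.array([[1, 2, 3, 4]])  # (batch=1, sequence_length=4)
    logits = model(tokens)
    # Expect one logit vector per input token: (1, 4, vocab_size)
    print(logits.shape)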