# mlx-examples/llms/mlx_lm/models/deepseek.py

from dataclasses import dataclass
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
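

# Configuration for the DeepSeek architecture. Field names mirror the
# Hugging Face checkpoint configuration; the MoE-related fields
# (n_routed_experts, num_experts_per_tok, n_shared_experts,
# first_k_dense_replace, moe_layer_freq) are optional, and when
# n_routed_experts is None every layer falls back to the dense MLP.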
@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str = "deepseek"
    vocab_size: int = 102400
    hidden_size: int = 4096
    intermediate_size: int = 11008
    moe_intermediate_size: int = 1407
    num_hidden_layers: int = 30
    num_attention_heads: int = 32
    num_key_value_heads: int = 32
    n_shared_experts: Optional[int] = None
    n_routed_experts: Optional[int] = None
    num_experts_per_tok: Optional[int] = None
    moe_layer_freq: int = 1
    first_k_dense_replace: int = 0
    max_position_embeddings: int = 2048
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    rope_scaling: Optional[Dict] = None
    attention_bias: bool = False
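

# Multi-head attention with rotary position embeddings (RoPE). Supports an
# optional KV cache for incremental decoding and linear RoPE scaling when
# config.rope_scaling is set.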
class DeepseekAttention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scale = self.head_dim**-0.5
        attention_bias = getattr(config, "attention_bias", False)
        self.q_proj = nn.Linear(
            self.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=attention_bias,
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=attention_bias,
        )
        # Output projection maps the concatenated head outputs back to the
        # model's hidden size.
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            self.hidden_size,
            bias=attention_bias,
        )
        rope_scale = 1.0
        if config.rope_scaling and config.rope_scaling["type"] == "linear":
            assert isinstance(config.rope_scaling["factor"], float)
            rope_scale = 1 / config.rope_scaling["factor"]
        self.rope = nn.RoPE(
            self.head_dim,
            base=config.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, _ = x.shape
        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
            0, 2, 1, 3
        )
        keys = keys.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)
        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
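

# Standard SwiGLU-style feed-forward block:
# down_proj(silu(gate_proj(x)) * up_proj(x)).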
class DeepseekMLP(nn.Module):
    def __init__(
        self,
        config: ModelArgs,
        hidden_size: Optional[int] = None,
        intermediate_size: Optional[int] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size or config.hidden_size
        self.intermediate_size = intermediate_size or config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.silu

    def __call__(self, x: mx.array) -> mx.array:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
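

# Router for the MoE layers: scores each token against every routed expert
# (softmax over a learned weight matrix) and returns the indices and softmax
# scores of the top-k experts per token.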
class MoEGate(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))

    def __call__(self, x):
        gates = x @ self.weight.T
        scores = mx.softmax(gates, axis=-1, precise=True)
        k = self.top_k
        inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
        scores = mx.take_along_axis(scores, inds, axis=-1)
        return inds, scores
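

# Mixture-of-experts feed-forward layer: the gate picks top-k routed experts
# per token, SwitchGLU evaluates only those experts, and their outputs are
# combined as a score-weighted sum. If configured, a shared (always-active)
# expert MLP is added on top.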
class DeepseekMoE(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.switch_mlp = SwitchGLU(
            config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
        )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekMLP(
                config=config, intermediate_size=intermediate_size
            )

    def __call__(self, x):
        inds, scores = self.gate(x)
        y = self.switch_mlp(x, inds)
        y = (y * scores[..., None]).sum(axis=-2)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(x)
        return y
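

# Pre-norm transformer block. A layer uses the MoE feed-forward instead of the
# dense MLP when routed experts are configured, the layer index is at least
# first_k_dense_replace, and the index is a multiple of moe_layer_freq.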
class DeepseekDecoderLayer(nn.Module):
    def __init__(self, config: ModelArgs, layer_idx: int):
        super().__init__()
        self.self_attn = DeepseekAttention(config)
        self.mlp = (
            DeepseekMoE(config)
            if (
                config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0
            )
            else DeepseekMLP(config)
        )
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out
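

# Decoder-only backbone: token embedding, a stack of decoder layers, and a
# final RMSNorm. The attention mask is built once and shared by all layers.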
class DeepseekModel(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            DeepseekDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        x: mx.array,
        cache: Optional[Any] = None,
    ) -> mx.array:
        h = self.embed_tokens(x)
        mask = create_attention_mask(h, cache)
        if cache is None:
            cache = [None] * len(self.layers)
        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)
        return self.norm(h)
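

# Top-level model with the language-modeling head. sanitize() converts
# checkpoints that store one set of weights per expert
# (mlp.experts.{e}.{gate,up,down}_proj) into the stacked layout expected by
# SwitchGLU (mlp.switch_mlp.{gate,up,down}_proj).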
class Model(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.args = config
        self.model_type = config.model_type
        self.model = DeepseekModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
    ):
        out = self.model(inputs, cache)
        return self.lm_head(out)

    def sanitize(self, weights):
        for l in range(self.args.num_hidden_layers):
            prefix = f"model.layers.{l}"
            for m in ["gate_proj", "down_proj", "up_proj"]:
                for k in ["weight", "scales", "biases"]:
                    if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
                        to_join = [
                            weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
                            for e in range(self.args.n_routed_experts)
                        ]
                        weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
        return weights

    @property
    def layers(self):
        return self.model.layers
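

# Minimal usage sketch: build a tiny model and run one forward pass on dummy
# token ids. The dimensions below are illustrative assumptions, not a real
# DeepSeek configuration. Because of the relative imports above, this only
# works when the module is executed inside the package
# (e.g. `python -m mlx_lm.models.deepseek`).
if __name__ == "__main__":
    args = ModelArgs(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        moe_intermediate_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        n_routed_experts=4,
        num_experts_per_tok=2,
        n_shared_experts=1,
        first_k_dense_replace=1,  # layer 0 stays dense, layer 1 uses MoE
    )
    model = Model(args)
    tokens = mx.array([[1, 2, 3, 4, 5]])
    logits = model(tokens)
    print(logits.shape)  # (1, 5, vocab_size)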