mlx-examples/llms/mlx_lm/models/deepseek.py
Awni Hannun fca087be49
More cache improvements (#1015)
* fix rotating kv cache for chat use case

* reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat

* nit in chat

* fix tests

* fix tests

* fix tests

* docs

* chat command

* comments + docs

* Define meta_state on all Cache implementations

* fixes + trim_prompt_cache api

* fix default model

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-10-07 20:45:51 -07:00

259 lines
8.5 KiB
Python

from dataclasses import dataclass
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyper-parameters used to build the DeepSeek model in this file.

    Field names, order and defaults are part of the public interface.
    The MoE fields default to ``None``; ``DeepseekDecoderLayer`` only builds
    a mixture-of-experts MLP when ``n_routed_experts`` is not ``None``.
    """

    model_type: str = "deepseek"

    # Embedding / transformer dimensions.
    vocab_size: int = 102400
    hidden_size: int = 4096
    intermediate_size: int = 11008  # width of the dense DeepseekMLP
    moe_intermediate_size: int = 1407  # width of each routed expert
    num_hidden_layers: int = 30
    num_attention_heads: int = 32
    num_key_value_heads: int = 32

    # Mixture-of-experts configuration.
    n_shared_experts: Optional[int] = None
    n_routed_experts: Optional[int] = None
    num_experts_per_tok: Optional[int] = None  # top-k used by MoEGate
    moe_layer_freq: int = 1  # a layer is MoE only if layer_idx % moe_layer_freq == 0
    first_k_dense_replace: int = 0  # layers with a smaller index stay dense

    # Positional encoding and normalization.
    max_position_embeddings: int = 2048
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    # Only {"type": "linear", "factor": ...} is interpreted by the attention layer.
    rope_scaling: Optional[Dict] = None
    attention_bias: bool = False
class DeepseekAttention(nn.Module):
    """Self-attention layer with rotary position embeddings.

    Queries use ``num_attention_heads`` heads while keys/values use
    ``num_key_value_heads`` heads; every head has
    ``hidden_size // num_attention_heads`` channels.
    """

    def __init__(self, config: ModelArgs):
        """Build the projections and the rotary embedding from ``config``.

        Raises:
            TypeError: if linear rope scaling is requested with a
                non-numeric ``factor``.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scale = self.head_dim**-0.5

        attention_bias = getattr(config, "attention_bias", False)

        self.q_proj = nn.Linear(
            self.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=attention_bias,
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=attention_bias,
        )
        # The output projection consumes the concatenated heads and maps them
        # back to the model width, so the head dimension is the *input* size.
        # (The two sizes only coincide when hidden_size is divisible by the
        # number of attention heads.)
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            self.hidden_size,
            bias=attention_bias,
        )

        rope_scale = 1.0
        if config.rope_scaling and config.rope_scaling["type"] == "linear":
            factor = config.rope_scaling["factor"]
            # Accept ints as well as floats (JSON configs often store e.g. 2),
            # and validate explicitly since asserts vanish under ``-O``.
            if isinstance(factor, bool) or not isinstance(factor, (int, float)):
                raise TypeError(
                    "rope_scaling['factor'] must be a number, "
                    f"got {type(factor).__name__}"
                )
            rope_scale = 1 / factor

        self.rope = nn.RoPE(
            self.head_dim,
            base=config.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Apply attention to ``x``.

        Args:
            x: rank-3 input unpacked as ``(B, L, _)``.
            mask: optional mask forwarded unchanged to
                ``mx.fast.scaled_dot_product_attention``.
            cache: optional key/value cache exposing ``offset`` and
                ``update_and_fetch(keys, values)``.

        Returns:
            The output projection of the attended values, shape
            ``(B, L, hidden_size)``.
        """
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Split channels into heads and move the head axis ahead of the
        # sequence axis: (B, L, H, D) -> (B, H, L, D).
        queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
            0, 2, 1, 3
        )
        keys = keys.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            # Rotate with the cache offset *before* updating the cache, so
            # both queries and keys see the same starting position.
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        # Merge the heads back: (B, H, L, D) -> (B, L, H * D).
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
class DeepseekMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""

    def __init__(
        self,
        config: ModelArgs,
        hidden_size: Optional[int] = None,
        intermediate_size: Optional[int] = None,
    ):
        """Create the three bias-free projections.

        Args:
            config: model hyper-parameters supplying the default sizes.
            hidden_size: input/output width; a falsy value (``None`` or 0)
                falls back to ``config.hidden_size``.
            intermediate_size: inner width; a falsy value falls back to
                ``config.intermediate_size``.
        """
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size if hidden_size else config.hidden_size
        self.intermediate_size = (
            intermediate_size if intermediate_size else config.intermediate_size
        )

        outer, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(outer, inner, bias=False)
        self.up_proj = nn.Linear(outer, inner, bias=False)
        self.down_proj = nn.Linear(inner, outer, bias=False)
        self.act_fn = nn.silu

    def __call__(self, x: mx.array) -> mx.array:
        """Return the gated feed-forward transform of ``x``."""
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
class MoEGate(nn.Module):
    """Router that selects ``num_experts_per_tok`` experts for each input row."""

    def __init__(self, config: ModelArgs):
        """Allocate a zero-initialized routing matrix of shape
        ``(n_routed_experts, hidden_size)``."""
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))

    def __call__(self, x):
        """Return ``(inds, scores)`` for the selected experts.

        ``inds`` holds the indices of the ``top_k`` highest softmax
        probabilities along the last axis (in partition order, not sorted)
        and ``scores`` holds the matching probabilities. The scores are not
        re-normalized after selection.
        """
        probs = mx.softmax(x @ self.weight.T, axis=-1, precise=True)
        top_k = self.top_k
        # Partitioning the negated probabilities places the largest top_k
        # entries in the first top_k slots.
        partitioned = mx.argpartition(-probs, kth=top_k - 1, axis=-1)
        # Index selection is not differentiable; keep it out of the graph.
        inds = mx.stop_gradient(partitioned[..., :top_k])
        return inds, mx.take_along_axis(probs, inds, axis=-1)
class DeepseekMoE(nn.Module):
    """Mixture-of-experts MLP: routed experts plus optional shared experts."""

    def __init__(self, config: ModelArgs):
        """Build the routed experts, the gate and, when
        ``config.n_shared_experts`` is not ``None``, the shared expert MLP."""
        super().__init__()
        self.config = config
        self.switch_mlp = SwitchGLU(
            config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
        )
        self.gate = MoEGate(config)

        if config.n_shared_experts is None:
            return

        # The shared experts are fused into a single MLP whose inner width is
        # the per-expert width times the number of shared experts.
        shared_width = config.moe_intermediate_size * config.n_shared_experts
        self.shared_experts = DeepseekMLP(
            config=config, intermediate_size=shared_width
        )

    def __call__(self, x):
        """Combine the selected experts' outputs, weighted by the gate scores,
        and add the shared-expert output when configured."""
        inds, scores = self.gate(x)
        expert_out = self.switch_mlp(x, inds)
        # Weight each selected expert's output and sum over the expert axis.
        y = (expert_out * scores[..., None]).sum(axis=-2)
        if self.config.n_shared_experts is None:
            return y
        return y + self.shared_experts(x)
class DeepseekDecoderLayer(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, config: ModelArgs, layer_idx: int):
        """Build the block for position ``layer_idx`` in the stack.

        The MLP is a ``DeepseekMoE`` only when routed experts are configured,
        the layer index is at least ``first_k_dense_replace`` and the index is
        a multiple of ``moe_layer_freq``; otherwise it is a dense
        ``DeepseekMLP``.
        """
        super().__init__()
        self.self_attn = DeepseekAttention(config)

        use_moe = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        if use_moe:
            self.mlp = DeepseekMoE(config)
        else:
            self.mlp = DeepseekMLP(config)

        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Run attention and the MLP on ``x``, adding a residual after each.

        ``mask`` and ``cache`` are forwarded unchanged to the attention layer.
        """
        h = x + self.self_attn(self.input_layernorm(x), mask, cache)
        return h + self.mlp(self.post_attention_layernorm(h))
class DeepseekModel(nn.Module):
    """Token embedding, the stack of decoder layers and the final RMSNorm."""

    def __init__(self, config: ModelArgs):
        """Create ``config.num_hidden_layers`` decoder layers plus the
        embedding table and output norm."""
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            DeepseekDecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        x: mx.array,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Embed the token ids in ``x`` and run them through every layer.

        ``cache``, when given, is iterated in lockstep with the layers (one
        entry per layer); when ``None`` each layer runs without a cache.
        Returns the normalized hidden states.
        """
        h = self.embed_tokens(x)
        mask = create_attention_mask(h, cache)

        layer_caches = cache if cache is not None else [None] * len(self.layers)
        for layer, layer_cache in zip(self.layers, layer_caches):
            h = layer(h, mask, layer_cache)

        return self.norm(h)
class Model(nn.Module):
    """DeepSeek language model: the decoder stack followed by the LM head."""

    def __init__(self, config: ModelArgs):
        """Build the backbone and the bias-free vocabulary projection."""
        super().__init__()
        self.args = config
        self.model_type = config.model_type
        self.model = DeepseekModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
    ):
        """Return the ``lm_head`` projection of the backbone output for
        ``inputs``."""
        return self.lm_head(self.model(inputs, cache))

    def sanitize(self, weights):
        """Stack per-expert weights into the layout used by ``switch_mlp``.

        For every layer and every expert projection, the entries
        ``...mlp.experts.{e}.{proj}.{param}`` are removed from ``weights`` and
        replaced by a single stacked array stored under
        ``...mlp.switch_mlp.{proj}.{param}``. Layers without an
        ``experts.0`` entry are left untouched. ``weights`` is modified in
        place and also returned.
        """
        num_experts = self.args.n_routed_experts
        for layer_idx in range(self.args.num_hidden_layers):
            prefix = f"model.layers.{layer_idx}"
            for proj in ("gate_proj", "down_proj", "up_proj"):
                for param in ("weight", "scales", "biases"):
                    # The presence of expert 0 signals an unstacked checkpoint.
                    if f"{prefix}.mlp.experts.0.{proj}.{param}" not in weights:
                        continue
                    per_expert = [
                        weights.pop(f"{prefix}.mlp.experts.{e}.{proj}.{param}")
                        for e in range(num_experts)
                    ]
                    weights[f"{prefix}.mlp.switch_mlp.{proj}.{param}"] = mx.stack(
                        per_expert
                    )
        return weights

    @property
    def layers(self):
        """The backbone's list of decoder layers."""
        return self.model.layers