Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00

* Unify attention mask creation in LLMs.

  Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc code to create a mask for the attention mechanism. This usually takes the form:

  ```
  mask = None
  if h.shape[1] > 1:
      mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
      mask = mask.astype(h.dtype)
  ```

  This creates a mask only if the input consists of more than one token, but it assumes the multi-token input arrives at the beginning of inference. If we are instead evaluating multiple tokens because of speculative decoding or prompt cache reuse, the mask will not have the correct shape and will raise an exception in the attention computation. Some of the models already implement mask creation correctly, with code like this:

  ```
  mask = None
  if h.shape[1] > 1:
      mask = create_additive_causal_mask(
          h.shape[1], cache[0].offset if cache is not None else 0
      )
      mask = mask.astype(h.dtype)
  ```

  This commit unifies attention mask creation across all models with a new function, `create_attention_mask`, reducing code duplication and helping every model support inference performance enhancements like those mentioned above.

* Allow batches in the LLM key-value cache.

  The current implementation of the LLM key-value cache assumes an input batch of size 1. Input batching (evaluating multiple alternative inputs at the same time) can be a valuable tool for speculative sampling and other techniques. This change removes the hard-coded batch size from the code that resizes the key-value cache.

* Simplify causal mask creation.

  Use the same code path regardless of whether there is an offset. Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717).

* Use old-style type annotations to avoid linter errors.
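For reference, a minimal sketch of the unified helper described in the first bullet might look like this (a paraphrase of the approach, not the exact code in `models/base.py`):

```
def create_attention_mask(h, cache=None):
    T = h.shape[1]
    if T <= 1:
        return None  # single-token decoding steps need no mask
    # Offset the causal mask by however many tokens the cache already holds.
    offset = cache[0].offset if cache is not None and cache[0] is not None else 0
    return create_additive_causal_mask(T, offset).astype(h.dtype)
```

Each model then calls `mask = create_attention_mask(h, cache)` once at the top of its forward pass, as the DeepSeek-V2 implementation below does.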
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, KVCache, create_attention_mask
from .switch_layers import SwitchGLU


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str = "deepseek_v2"
    vocab_size: int = 102400
    hidden_size: int = 4096
    intermediate_size: int = 11008
    moe_intermediate_size: int = 1407
    num_hidden_layers: int = 30
    num_attention_heads: int = 32
    num_key_value_heads: int = 32
    n_shared_experts: Optional[int] = None
    n_routed_experts: Optional[int] = None
    routed_scaling_factor: float = 1.0
    kv_lora_rank: int = 512
    q_lora_rank: int = 1536
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    qk_nope_head_dim: int = 128
    topk_method: str = "greedy"
    n_group: Optional[int] = None
    topk_group: Optional[int] = None
    num_experts_per_tok: Optional[int] = None
    moe_layer_freq: int = 1
    first_k_dense_replace: int = 0
    max_position_embeddings: int = 2048
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    rope_scaling: Optional[Dict] = None
    attention_bias: bool = False


# Helpers for YaRN RoPE scaling (https://arxiv.org/abs/2309.00071): find the range of
# rotary dimensions to blend between extrapolated and interpolated frequencies.
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    low = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    high = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    return max(low, 0), min(high, dim - 1)


def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def yarn_linear_ramp_mask(min, max, dim):
    if min == max:
        max += 0.001  # Prevent singularity

    linear_func = (mx.arange(dim, dtype=mx.float32) - min) / (max - min)
    ramp_func = mx.clip(linear_func, 0, 1)
    return ramp_func


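# For intuition about the helpers above: with dim=64 (the default qk_rope_head_dim),
# base=10000, original_max_position_embeddings=4096, beta_fast=32 and beta_slow=1,
# yarn_find_correction_range returns roughly (10, 23). Frequency indices below 10 keep
# the original base (extrapolation), those above 23 are fully divided by the scaling
# factor (interpolation), and the indices in between are blended by yarn_linear_ramp_mask.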
class DeepseekV2YarnRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim

        self.max_seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None
        self._inv_freq = None
        self.set_cos_sin_cache(max_position_embeddings)

    def set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        dim = self.dim
        freq_extra = 1.0 / (self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))
        freq_inter = 1.0 / (
            self.scaling_factor
            * self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
        )

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self._inv_freq = inv_freq

        t = mx.arange(seq_len, dtype=mx.float32)
        freqs = mx.outer(t, inv_freq)

        mscale = yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(
            self.scaling_factor, self.mscale_all_dim
        )

        self._cos_cached = mx.cos(freqs) * mscale
        self._sin_cached = mx.sin(freqs) * mscale

    def apply_rotary_pos_emb(self, x, cos, sin):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * cos - x2 * sin
        rx2 = x1 * sin + x2 * cos
        return mx.concatenate([rx1, rx2], axis=-1)

    def __call__(self, x, offset=0):
        seq_len = offset + x.shape[2]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self.set_cos_sin_cache(seq_len=seq_len)

        if self._cos_cached.dtype != x.dtype:
            self._cos_cached = self._cos_cached.astype(x.dtype)
            self._sin_cached = self._sin_cached.astype(x.dtype)

        return self.apply_rotary_pos_emb(
            x,
            self._cos_cached[offset:seq_len],
            self._sin_cached[offset:seq_len],
        )


class DeepseekV2Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.scale = self.q_head_dim**-0.5

        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            self.q_a_proj = nn.Linear(
                self.hidden_size, self.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
            self.q_b_proj = nn.Linear(
                self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )

        if self.config.rope_scaling is not None:
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.scale = self.scale * mscale * mscale

            rope_kwargs = {
                key: self.config.rope_scaling[key]
                for key in [
                    "original_max_position_embeddings",
                    "beta_fast",
                    "beta_slow",
                    "mscale",
                    "mscale_all_dim",
                ]
                if key in self.config.rope_scaling
            }
            self.rope = DeepseekV2YarnRotaryEmbedding(
                dim=self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
                base=self.rope_theta,
                **rope_kwargs,
            )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        B, L, D = x.shape

        if self.q_lora_rank is None:
            q = self.q_proj(x)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))

        q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
        q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
        compressed_kv = self.kv_a_proj_with_mqa(x)
        compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
        k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
        kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)

        k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)

        k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)

        if cache is not None:
            q_pe = self.rope(q_pe, cache.offset)
            k_pe = self.rope(k_pe, cache.offset)
            keys, values = cache.update_and_fetch(
                mx.concatenate([k_nope, k_pe], axis=-1), values
            )
        else:
            q_pe = self.rope(q_pe)
            k_pe = self.rope(k_pe)
            keys = mx.concatenate([k_nope, k_pe], axis=-1)

        queries = mx.concatenate([q_nope, q_pe], axis=-1)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class DeepseekV2MLP(nn.Module):
    def __init__(
        self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
        self.intermediate_size = (
            config.intermediate_size if intermediate_size is None else intermediate_size
        )

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def __call__(self, x):
        down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class MoEGate(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))

    def __call__(self, x):
        gates = x @ self.weight.T

        scores = mx.softmax(gates, axis=-1, precise=True)

        if self.topk_method == "group_limited_greedy":
            # Keep only the top `topk_group` expert groups per token by zeroing the
            # scores of the lowest-scoring groups before the final top-k selection.
            bsz, seq_len = x.shape[:2]
            scores = scores.reshape(bsz, seq_len, self.n_group, -1)
            group_scores = scores.max(axis=-1)
            k = self.n_group - self.topk_group
            group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
            batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
            seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
            scores[batch_idx, seq_idx, group_idx] = 0.0
            scores = scores.reshape(bsz, seq_len, -1)

        k = self.top_k
        inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
        scores = mx.take_along_axis(scores, inds, axis=-1)
        scores = scores * self.routed_scaling_factor

        return inds, scores


class DeepseekV2MoE(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
        self.switch_mlp = SwitchGLU(
            config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
        )

        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekV2MLP(
                config=config, intermediate_size=intermediate_size
            )

    def __call__(self, x):
        inds, scores = self.gate(x)
        y = self.switch_mlp(x, inds)
        y = (y * scores[..., None]).sum(axis=-2)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(x)

        return y


class DeepseekV2DecoderLayer(nn.Module):
    def __init__(self, config: ModelArgs, layer_idx: int):
        super().__init__()
        self.self_attn = DeepseekV2Attention(config)
        self.mlp = (
            DeepseekV2MoE(config)
            if (
                config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0
            )
            else DeepseekV2MLP(config)
        )
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out


class DeepseekV2Model(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            DeepseekV2DecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        x: mx.array,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        h = self.embed_tokens(x)
        mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.args = config
        self.model_type = config.model_type
        self.model = DeepseekV2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[KVCache] = None,
    ):
        out = self.model(inputs, cache)
        return self.lm_head(out)

    def sanitize(self, weights):
        # Stack the per-expert MLP weights from the original checkpoint layout into
        # the single arrays expected by SwitchGLU.
        for l in range(self.args.num_hidden_layers):
            prefix = f"model.layers.{l}"
            for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
                for k in ["weight", "scales", "biases"]:
                    if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
                        to_join = [
                            weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
                            for e in range(self.args.n_routed_experts)
                        ]
                        weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
        return weights

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return (
            self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
            self.args.v_head_dim,
        )

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
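Below is a hypothetical usage sketch (not part of the file) showing how the classes above fit together: it builds a small, randomly initialized DeepSeek-V2 model directly from `ModelArgs` and runs one forward pass. The tiny dimensions and the `rope_scaling` dict are illustrative values, assuming the `mlx_lm` package layout.

```
import mlx.core as mx
from mlx_lm.models.deepseek_v2 import Model, ModelArgs

args = ModelArgs(
    vocab_size=1000,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    q_lora_rank=None,        # use the plain q_proj path
    kv_lora_rank=32,
    qk_rope_head_dim=16,
    qk_nope_head_dim=16,
    v_head_dim=16,
    # rope_scaling is required by DeepseekV2Attention as written above.
    rope_scaling={"factor": 1.0, "original_max_position_embeddings": 2048},
)
model = Model(args)

tokens = mx.array([[1, 2, 3, 4]])  # batch of one prompt, four tokens
logits = model(tokens)             # cache=None, so a causal mask is built internally
print(logits.shape)                # (1, 4, 1000)
```

During generation, `mlx_lm` passes a per-layer list of `KVCache` objects via the `cache` argument so that subsequent single-token calls reuse the cached keys and values.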