mlx-examples/llms/mlx_lm/models/phimoe.py

import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, KVCache, create_attention_mask
from .su_rope import SuScaledRotaryEmbedding
from .switch_layers import SwitchGLU

@dataclass
class ModelArgs(BaseModelArgs):
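    # Configuration for PhiMoE-style models; the defaults mirror the
    # Phi-3.5-MoE config (16 local experts, top-2 routing, LongRoPE scaling
    # for contexts up to 131072 tokens).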
    model_type: str = "phimoe"
    vocab_size: int = 32064
    hidden_size: int = 4096
    intermediate_size: int = 6400
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    max_position_embeddings: int = 131072
    original_max_position_embeddings: int = 4096
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-6
    pad_token_id: Optional[int] = None
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
    num_local_experts: int = 16
    num_experts_per_tok: int = 2
    attention_bias: bool = False
    rope_theta: float = 10000.0

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"long_factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] not in ["longrope", "su", "linear"]:
                print(
                    "[WARNING] rope_scaling 'type' currently only supports 'linear', 'su', and 'longrope'; setting rope scaling to false."
                )
                self.rope_scaling = None


class Attention(nn.Module):
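    # Multi-head attention with grouped-query KV heads (n_kv_heads < n_heads).
    # Depending on the rope_scaling type, positions are encoded with either the
    # SuScaled ("su"/"longrope") rotary embedding or standard RoPE with an
    # optional linear scale.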
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        assert args.num_key_value_heads is not None
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)

        rope_scale = 1.0
        if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
            self.rope = SuScaledRotaryEmbedding(
                head_dim,
                base=args.rope_theta,
                max_position_embeddings=args.max_position_embeddings,
                original_max_position_embeddings=args.original_max_position_embeddings,
                short_factor=args.rope_scaling["short_factor"],
                long_factor=args.rope_scaling["long_factor"],
            )
        else:
            if args.rope_scaling and args.rope_scaling["type"] == "linear":
                assert isinstance(args.rope_scaling["factor"], float)
                rope_scale = 1 / args.rope_scaling["factor"]
            self.rope = nn.RoPE(
                head_dim,
                traditional=args.rope_traditional,
                base=args.rope_theta,
                scale=rope_scale,
            )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class PhiMoESparseMoeBlock(nn.Module):
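    # Sparse mixture-of-experts feed-forward block: a linear gate scores every
    # expert per token, the top-k experts are selected, and their outputs are
    # combined with softmax-normalized gate scores via the batched SwitchGLU.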
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_dim = args.hidden_size
        self.ffn_dim = args.intermediate_size
        self.num_experts = args.num_local_experts
        self.top_k = args.num_experts_per_tok

        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)

    def __call__(self, x: mx.array) -> mx.array:
        gates = self.gate(x)

        k = self.top_k
        # Select the indices of the k highest-scoring experts for each token.
        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
        scores = mx.take_along_axis(gates, inds, axis=-1)
        scores = mx.softmax(scores, axis=-1, precise=True)

        # Run only the selected experts and mix their outputs by gate score.
        y = self.switch_mlp(x, inds)
        y = (y * scores[..., None]).sum(axis=-2)
        return y


class PhiMoEDecoderLayer(nn.Module):
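    # One decoder block: pre-norm (LayerNorm) self-attention and MoE
    # sub-layers, each wrapped in a residual connection.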
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_size = args.hidden_size

        self.self_attn = Attention(args)
        self.block_sparse_moe = PhiMoESparseMoeBlock(args)
        self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        residual = x
        hidden_states = self.input_layernorm(x)
        hidden_states = self.self_attn(x=hidden_states, mask=mask, cache=cache)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.block_sparse_moe(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class PhiMoEModel(nn.Module):
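    # The decoder stack: token embeddings, num_hidden_layers decoder blocks,
    # and a final LayerNorm. The causal attention mask is built by
    # create_attention_mask from the inputs and the KV cache offset.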
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [PhiMoEDecoderLayer(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ) -> mx.array:
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
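    # Top-level model exposing the language-modeling head. `sanitize` maps the
    # per-expert weights from the original checkpoint layout onto the stacked
    # tensors expected by SwitchGLU.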
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model = PhiMoEModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        return self.lm_head(out)

    def sanitize(self, weights):
        if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
            return weights

        # Stack the per-expert w1/w2/w3 weights (and any quantization scales
        # or biases) into single tensors for the fused SwitchGLU projections.
        for l in range(self.args.num_hidden_layers):
            prefix = f"model.layers.{l}"
            for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
                for k in ["weight", "scales", "biases"]:
                    if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
                        to_join = [
                            weights.pop(
                                f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
                            )
                            for e in range(self.args.num_local_experts)
                        ]
                        weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
                            mx.stack(to_join)
                        )
        return weights

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads