support hunyuan 7b (#1263)

Awni Hannun 2025-02-08 15:46:47 -08:00 committed by GitHub
parent 31611b62d7
commit 1503bd4f55


@@ -76,7 +76,6 @@ class Attention(nn.Module):
         head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5
         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
         if kv_proj:
             self.k_proj = nn.Linear(
@@ -107,7 +106,6 @@ class Attention(nn.Module):
         B, L, D = x.shape
         queries = self.q_proj(x)
         if kv_states is None:
             keys, values = self.k_proj(x), self.v_proj(x)
             kv_states = keys, values
@@ -198,7 +196,10 @@ class DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(kv_proj, args)
-        self.mlp = MoeBlock(args)
+        if args.num_experts == 1:
+            self.mlp = MLP(args.hidden_size, args.intermediate_size)
+        else:
+            self.mlp = MoeBlock(args)
         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.post_attention_layernorm = nn.RMSNorm(
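When num_experts is 1 (as for the dense 7B model), each decoder layer now uses a plain gated MLP instead of the MoeBlock used by the MoE variants. Below is a minimal sketch of the kind of SwiGLU feed-forward block that MLP(args.hidden_size, args.intermediate_size) refers to; the real class is defined elsewhere in hunyuan.py, so the details here are assumptions rather than code copied from the diff:

import mlx.core as mx
import mlx.nn as nn


class MLP(nn.Module):
    # Sketch only: standard gate/up/down SwiGLU layout, consistent with the
    # gate_proj/up_proj keys that sanitize() produces further down; down_proj
    # is assumed.
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))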
@@ -231,7 +232,10 @@ class HunYuanModel(nn.Module):
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.layers = [
-            DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0)
+            DecoderLayer(
+                args=args,
+                kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0,
+            )
             for i in range(args.num_hidden_layers)
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
@@ -251,7 +255,7 @@ class HunYuanModel(nn.Module):
             cache = [None] * len(self.layers)

         for i, (layer, c) in enumerate(zip(self.layers, cache)):
-            if i % self.args.cla_share_factor == 0:
+            if (not self.args.use_cla) or i % self.args.cla_share_factor == 0:
                 shared_kv_states = None
             h, shared_kv_states = layer(h, mask, c, shared_kv_states)
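The use_cla guard in the two hunks above makes cross-layer attention (CLA) optional: with it enabled, only every cla_share_factor-th layer owns k/v projections and computes fresh key/value states, and the layers in between reuse the shared_kv_states handed down the loop; with it disabled, every layer projects its own keys and values. A small sketch of the resulting pattern, with made-up configuration values:

# Illustrative values only; the real ones come from the model config.
use_cla = True
cla_share_factor = 2
num_hidden_layers = 6

for i in range(num_hidden_layers):
    fresh_kv = (not use_cla) or i % cla_share_factor == 0
    print(f"layer {i}:", "computes new K/V" if fresh_kv else "reuses shared K/V")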
@@ -275,6 +279,29 @@ class Model(nn.Module):
         return self.model.embed_tokens.as_linear(out)

     def sanitize(self, weights):
+        if "model.layers.0.mlp.gate_and_up_proj.weight" in weights:
+            new_weights = {}
+            D = self.args.hidden_size
+            n_kv_heads = self.args.num_key_value_heads
+            n_kv_groups = self.args.num_attention_heads // n_kv_heads
+            head_dim = D // self.args.num_attention_heads
+            for k, v in weights.items():
+                if "qkv_proj" in k:
+                    v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
+                    splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1)
+                    for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits):
+                        k_new = k.replace("qkv_proj", k_up)
+                        new_weights[k_new] = mx.flatten(v_new, 0, 2)
+                elif "gate_and_up_proj" in k:
+                    splits = v.split(2, axis=0)
+                    for k_up, v_new in zip(["up_proj", "gate_proj"], splits):
+                        k_new = k.replace("gate_and_up_proj", k_up)
+                        new_weights[k_new] = v_new
+                else:
+                    new_weights[k] = v
+            weights = new_weights
+
         if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
             return weights
         for l in range(self.args.num_hidden_layers):
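The new branch in sanitize unpacks checkpoints that ship fused projection weights: qkv_proj weights are regrouped per key/value head and split back into separate q/k/v projections, and gate_and_up_proj weights are cut in two to recover up_proj and gate_proj. A standalone sketch of the qkv split on toy shapes (dimensions are illustrative, not the real Hunyuan 7B config):

import mlx.core as mx

# Illustrative dimensions only.
hidden = 32
n_heads, n_kv_heads, head_dim = 4, 2, 8
n_kv_groups = n_heads // n_kv_heads  # query heads per kv head

# A fused qkv weight stores, for each kv head, its query heads followed
# by one key head and one value head, each contributing head_dim rows.
qkv = mx.zeros(((n_heads + 2 * n_kv_heads) * head_dim, hidden))

# Regroup the rows per kv head, then slice off the key and value parts.
qkv = qkv.reshape(n_kv_heads, n_kv_groups + 2, head_dim, hidden)
q, k, v = mx.split(qkv, [n_kv_groups, n_kv_groups + 1], axis=1)

# Flatten back to 2D projection weights.
print(mx.flatten(q, 0, 2).shape)  # (n_heads * head_dim, hidden)
print(mx.flatten(k, 0, 2).shape)  # (n_kv_heads * head_dim, hidden)
print(mx.flatten(v, 0, 2).shape)  # (n_kv_heads * head_dim, hidden)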