Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
support hunyuan 7b (#1263)

commit 1503bd4f55
parent 31611b62d7
@@ -76,7 +76,6 @@ class Attention(nn.Module):
         head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5
 
         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
         if kv_proj:
             self.k_proj = nn.Linear(
@@ -107,7 +106,6 @@ class Attention(nn.Module):
         B, L, D = x.shape
 
         queries = self.q_proj(x)
 
         if kv_states is None:
             keys, values = self.k_proj(x), self.v_proj(x)
             kv_states = keys, values
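The kv_states argument is what the CLA (cross-layer attention) sharing hinges on: the first layer in a sharing group projects keys and values, and its partner layers skip their own k_proj/v_proj and reuse the pair. A minimal standalone sketch of that handoff, with toy dimensions and assumed helper projections rather than the module's real code:

import mlx.core as mx
import mlx.nn as nn

# Toy dimensions, purely illustrative.
D = 16
x = mx.random.normal((1, 4, D))  # (batch, sequence length, hidden size)

q_proj = nn.Linear(D, D)
k_proj = nn.Linear(D, D)
v_proj = nn.Linear(D, D)

def attention_inputs(x, kv_states=None):
    queries = q_proj(x)
    if kv_states is None:                 # this layer owns its K/V projections
        kv_states = k_proj(x), v_proj(x)
    keys, values = kv_states              # a sharing partner reuses the cached pair
    return queries, keys, values, kv_states

# Layer 0 computes K/V; its CLA partner reuses them unchanged.
_, k0, v0, shared = attention_inputs(x)
_, k1, v1, _ = attention_inputs(x, kv_states=shared)
assert k1 is k0 and v1 is v0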
@@ -198,7 +196,10 @@ class DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(kv_proj, args)
-        self.mlp = MoeBlock(args)
+        if args.num_experts == 1:
+            self.mlp = MLP(args.hidden_size, args.intermediate_size)
+        else:
+            self.mlp = MoeBlock(args)
 
         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.post_attention_layernorm = nn.RMSNorm(
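The dense path instantiates MLP(args.hidden_size, args.intermediate_size), a class not shown in this hunk; presumably it is the usual gated-SiLU feed-forward block used across mlx-lm models. A sketch of that shape, as an assumption rather than a copy of the file's class:

import mlx.core as mx
import mlx.nn as nn

class MLP(nn.Module):
    # Assumed gated-SiLU feed-forward; the real class in hunyuan.py may differ in detail.
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))

Keeping MoeBlock for num_experts > 1 lets the same DecoderLayer serve both the dense 7B model and the MoE variants.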
@@ -231,7 +232,10 @@ class HunYuanModel(nn.Module):
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.layers = [
-            DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0)
+            DecoderLayer(
+                args=args,
+                kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0,
+            )
             for i in range(args.num_hidden_layers)
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
@@ -251,7 +255,7 @@ class HunYuanModel(nn.Module):
             cache = [None] * len(self.layers)
 
         for i, (layer, c) in enumerate(zip(self.layers, cache)):
-            if i % self.args.cla_share_factor == 0:
+            if (not self.args.use_cla) or i % self.args.cla_share_factor == 0:
                 shared_kv_states = None
             h, shared_kv_states = layer(h, mask, c, shared_kv_states)
 
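Both rewritten conditions encode the same rule: with use_cla disabled (as in the dense 7B config), every layer projects its own keys and values, while with it enabled only every cla_share_factor-th layer does. A quick check of that indexing with made-up values:

# Made-up values, just to show which layers own their K/V projections.
def owns_kv(i, use_cla, cla_share_factor):
    return (not use_cla) or (i % cla_share_factor) == 0

print([owns_kv(i, use_cla=True, cla_share_factor=2) for i in range(6)])
# [True, False, True, False, True, False] -> odd layers reuse shared_kv_states
print([owns_kv(i, use_cla=False, cla_share_factor=2) for i in range(6)])
# [True, True, True, True, True, True]    -> every layer projects its own K/V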
@@ -275,6 +279,29 @@ class Model(nn.Module):
         return self.model.embed_tokens.as_linear(out)
 
     def sanitize(self, weights):
+
+        if "model.layers.0.mlp.gate_and_up_proj.weight" in weights:
+            new_weights = {}
+            D = self.args.hidden_size
+            n_kv_heads = self.args.num_key_value_heads
+            n_kv_groups = self.args.num_attention_heads // n_kv_heads
+            head_dim = D // self.args.num_attention_heads
+            for k, v in weights.items():
+                if "qkv_proj" in k:
+                    v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
+                    splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1)
+                    for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits):
+                        k_new = k.replace("qkv_proj", k_up)
+                        new_weights[k_new] = mx.flatten(v_new, 0, 2)
+                elif "gate_and_up_proj" in k:
+                    splits = v.split(2, axis=0)
+                    for k_up, v_new in zip(["up_proj", "gate_proj"], splits):
+                        k_new = k.replace("gate_and_up_proj", k_up)
+                        new_weights[k_new] = v_new
+                else:
+                    new_weights[k] = v
+            weights = new_weights
+
         if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
             return weights
         for l in range(self.args.num_hidden_layers):
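The qkv_proj branch of sanitize assumes the fused weight is laid out per KV head as n_kv_groups query slices followed by one key and one value slice, so reshaping to (n_kv_heads, n_kv_groups + 2, head_dim, -1) and splitting on axis 1 recovers the separate projections. A shape check with arbitrary toy dimensions:

import mlx.core as mx

# Toy dimensions; only the shapes matter here.
D, n_heads, n_kv_heads = 32, 8, 4
head_dim = D // n_heads
n_kv_groups = n_heads // n_kv_heads

# Fused weight: (n_kv_heads * (n_kv_groups + 2) * head_dim, D)
fused = mx.zeros((n_kv_heads * (n_kv_groups + 2) * head_dim, D))

v = fused.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
q, k, val = v.split([n_kv_groups, n_kv_groups + 1], axis=1)

print(mx.flatten(q, 0, 2).shape)    # (n_heads * head_dim, D) = (32, 32)
print(mx.flatten(k, 0, 2).shape)    # (n_kv_heads * head_dim, D) = (16, 32)
print(mx.flatten(val, 0, 2).shape)  # (n_kv_heads * head_dim, D) = (16, 32)

The gate_and_up_proj branch is simpler: the fused tensor is split in half along axis 0, yielding the up_proj and gate_proj weights.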