mirror of https://github.com/ml-explore/mlx-examples.git

support hunyuan 7b (#1263)

commit 1503bd4f55 (parent 31611b62d7)
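Summary of the hunks below: the decoder layer now builds a dense MLP when args.num_experts == 1 instead of always using a MoeBlock, cross-layer attention (CLA) KV sharing is gated on args.use_cla both when constructing layers and in the forward pass, and Model.sanitize gains logic to split fused qkv_proj and gate_and_up_proj checkpoint weights into the separate q/k/v and up/gate projections the modules expect.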
@@ -76,7 +76,6 @@ class Attention(nn.Module):
 
         head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5
-
         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
         if kv_proj:
             self.k_proj = nn.Linear(
@@ -107,7 +106,6 @@ class Attention(nn.Module):
         B, L, D = x.shape
 
         queries = self.q_proj(x)
-
         if kv_states is None:
             keys, values = self.k_proj(x), self.v_proj(x)
             kv_states = keys, values
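Aside (not part of the commit): the kv_states argument above is what carries HunYuan's cross-layer attention (CLA) sharing. A layer built with kv_proj=False has no k_proj/v_proj of its own and reuses the keys/values produced by the last projecting layer. A stripped-down, runnable illustration of that hand-off, with made-up dimensions and class name and without the RoPE/KV-cache handling of the real __call__:

import mlx.core as mx
import mlx.nn as nn


class TinyAttention(nn.Module):
    def __init__(self, dim: int, kv_proj: bool):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        if kv_proj:
            # Only projecting layers own key/value projections.
            self.k_proj = nn.Linear(dim, dim)
            self.v_proj = nn.Linear(dim, dim)

    def __call__(self, x, kv_states=None):
        q = self.q_proj(x)
        if kv_states is None:
            kv_states = self.k_proj(x), self.v_proj(x)
        k, v = kv_states
        scores = mx.softmax((q @ k.transpose(0, 2, 1)) / q.shape[-1] ** 0.5, axis=-1)
        return scores @ v, kv_states


x = mx.random.normal((1, 5, 16))
layer0 = TinyAttention(16, kv_proj=True)
layer1 = TinyAttention(16, kv_proj=False)  # no KV projections of its own
out0, shared = layer0(x)
out1, _ = layer1(out0, shared)  # reuses layer0's keys and values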
@@ -198,7 +196,10 @@ class DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(kv_proj, args)
-        self.mlp = MoeBlock(args)
+        if args.num_experts == 1:
+            self.mlp = MLP(args.hidden_size, args.intermediate_size)
+        else:
+            self.mlp = MoeBlock(args)
 
         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.post_attention_layernorm = nn.RMSNorm(
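The dense branch calls an MLP class with (hidden_size, intermediate_size); that class lives elsewhere in the same file and is not shown in this diff. For orientation, a minimal sketch of what such a block typically looks like in mlx_lm models. The SwiGLU layout below is an assumption (consistent with the gate_proj/up_proj/down_proj names handled in sanitize), not code taken from this commit:

import mlx.core as mx
import mlx.nn as nn


class MLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        # SwiGLU: the gated activation modulates the up projection
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))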
@@ -231,7 +232,10 @@ class HunYuanModel(nn.Module):
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.layers = [
-            DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0)
+            DecoderLayer(
+                args=args,
+                kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0,
+            )
             for i in range(args.num_hidden_layers)
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
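The effect of the new kv_proj expression is easy to check in isolation. The layer count and share factor below are arbitrary example values, not HunYuan's real configuration:

# Which layers own their own k_proj/v_proj under the updated rule.
def kv_proj_flags(num_hidden_layers: int, use_cla: bool, cla_share_factor: int):
    return [
        (not use_cla) or (i % cla_share_factor) == 0
        for i in range(num_hidden_layers)
    ]

print(kv_proj_flags(6, use_cla=True, cla_share_factor=2))
# [True, False, True, False, True, False] -> every other layer reuses shared KV
print(kv_proj_flags(6, use_cla=False, cla_share_factor=2))
# [True, True, True, True, True, True]   -> no CLA: every layer projects its own KV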
@@ -251,7 +255,7 @@ class HunYuanModel(nn.Module):
             cache = [None] * len(self.layers)
 
         for i, (layer, c) in enumerate(zip(self.layers, cache)):
-            if i % self.args.cla_share_factor == 0:
+            if (not self.args.use_cla) or i % self.args.cla_share_factor == 0:
                 shared_kv_states = None
             h, shared_kv_states = layer(h, mask, c, shared_kv_states)
 
@@ -275,6 +279,29 @@ class Model(nn.Module):
         return self.model.embed_tokens.as_linear(out)
 
     def sanitize(self, weights):
+
+        if "model.layers.0.mlp.gate_and_up_proj.weight" in weights:
+            new_weights = {}
+            D = self.args.hidden_size
+            n_kv_heads = self.args.num_key_value_heads
+            n_kv_groups = self.args.num_attention_heads // n_kv_heads
+            head_dim = D // self.args.num_attention_heads
+            for k, v in weights.items():
+                if "qkv_proj" in k:
+                    v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
+                    splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1)
+                    for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits):
+                        k_new = k.replace("qkv_proj", k_up)
+                        new_weights[k_new] = mx.flatten(v_new, 0, 2)
+                elif "gate_and_up_proj" in k:
+                    splits = v.split(2, axis=0)
+                    for k_up, v_new in zip(["up_proj", "gate_proj"], splits):
+                        k_new = k.replace("gate_and_up_proj", k_up)
+                        new_weights[k_new] = v_new
+                else:
+                    new_weights[k] = v
+            weights = new_weights
+
         if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
             return weights
         for l in range(self.args.num_hidden_layers):
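To see what the qkv_proj handling in sanitize does, here is a stand-alone shape check using mx.split/mx.flatten directly. The sizes are made-up example values, not the 7B config:

import mlx.core as mx

D = 64           # hidden_size (example value)
n_heads = 8      # num_attention_heads (example value)
n_kv_heads = 2   # num_key_value_heads (example value)
n_kv_groups = n_heads // n_kv_heads
head_dim = D // n_heads

# Fused weight: for each KV head there are n_kv_groups query slices plus one key
# and one value slice, each head_dim rows tall.
qkv = mx.zeros((n_kv_heads * (n_kv_groups + 2) * head_dim, D))

v = qkv.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
q, k, vv = mx.split(v, [n_kv_groups, n_kv_groups + 1], axis=1)

print(mx.flatten(q, 0, 2).shape)   # (n_heads * head_dim, D)    -> q_proj weight
print(mx.flatten(k, 0, 2).shape)   # (n_kv_heads * head_dim, D) -> k_proj weight
print(mx.flatten(vv, 0, 2).shape)  # (n_kv_heads * head_dim, D) -> v_proj weight

The gate_and_up_proj case is simpler: the fused matrix is split in half along axis 0, the first half becoming up_proj and the second gate_proj.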