From 1503bd4f550886092b156ec897e633b448bd78bc Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sat, 8 Feb 2025 15:46:47 -0800
Subject: [PATCH] support hunyuan 7b (#1263)

---
 llms/mlx_lm/models/hunyuan.py | 37 ++++++++++++++++++++++++++++++-----
 1 file changed, 32 insertions(+), 5 deletions(-)

diff --git a/llms/mlx_lm/models/hunyuan.py b/llms/mlx_lm/models/hunyuan.py
index f9dc5652..122cebda 100644
--- a/llms/mlx_lm/models/hunyuan.py
+++ b/llms/mlx_lm/models/hunyuan.py
@@ -76,7 +76,6 @@ class Attention(nn.Module):
         head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5
 
-
         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
         if kv_proj:
             self.k_proj = nn.Linear(
@@ -107,7 +106,6 @@ class Attention(nn.Module):
         B, L, D = x.shape
 
         queries = self.q_proj(x)
-
         if kv_states is None:
             keys, values = self.k_proj(x), self.v_proj(x)
             kv_states = keys, values
@@ -198,7 +196,10 @@ class DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(kv_proj, args)
-        self.mlp = MoeBlock(args)
+        if args.num_experts == 1:
+            self.mlp = MLP(args.hidden_size, args.intermediate_size)
+        else:
+            self.mlp = MoeBlock(args)
 
         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.post_attention_layernorm = nn.RMSNorm(
@@ -231,7 +232,10 @@ class HunYuanModel(nn.Module):
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.layers = [
-            DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0)
+            DecoderLayer(
+                args=args,
+                kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0,
+            )
             for i in range(args.num_hidden_layers)
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
@@ -251,7 +255,7 @@ class HunYuanModel(nn.Module):
             cache = [None] * len(self.layers)
 
         for i, (layer, c) in enumerate(zip(self.layers, cache)):
-            if i % self.args.cla_share_factor == 0:
+            if (not self.args.use_cla) or i % self.args.cla_share_factor == 0:
                 shared_kv_states = None
             h, shared_kv_states = layer(h, mask, c, shared_kv_states)
 
@@ -275,6 +279,29 @@ class Model(nn.Module):
         return self.model.embed_tokens.as_linear(out)
 
     def sanitize(self, weights):
+
+        if "model.layers.0.mlp.gate_and_up_proj.weight" in weights:
+            new_weights = {}
+            D = self.args.hidden_size
+            n_kv_heads = self.args.num_key_value_heads
+            n_kv_groups = self.args.num_attention_heads // n_kv_heads
+            head_dim = D // self.args.num_attention_heads
+            for k, v in weights.items():
+                if "qkv_proj" in k:
+                    v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
+                    splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1)
+                    for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits):
+                        k_new = k.replace("qkv_proj", k_up)
+                        new_weights[k_new] = mx.flatten(v_new, 0, 2)
+                elif "gate_and_up_proj" in k:
+                    splits = v.split(2, axis=0)
+                    for k_up, v_new in zip(["up_proj", "gate_proj"], splits):
+                        k_new = k.replace("gate_and_up_proj", k_up)
+                        new_weights[k_new] = v_new
+                else:
+                    new_weights[k] = v
+            weights = new_weights
+
         if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
             return weights
         for l in range(self.args.num_hidden_layers):
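
Note (not part of the patch): the new sanitize logic un-fuses the packed Hunyuan 7B weights so they match the per-projection layers defined in hunyuan.py. The following is a minimal standalone sketch of that reshape/split, using invented toy dimensions rather than the model's real config; only the split logic mirrors the patch.

# Toy illustration of the weight un-fusing performed by Model.sanitize above.
# All dimensions here are invented for the example; only the reshape/split
# logic mirrors the patch.
import mlx.core as mx

n_heads, n_kv_heads, head_dim, hidden, intermediate = 8, 2, 16, 128, 256
n_kv_groups = n_heads // n_kv_heads

# A fused qkv_proj weight stacks, per kv head, its group of query rows
# followed by one key slice and one value slice.
w_qkv = mx.random.normal(((n_heads + 2 * n_kv_heads) * head_dim, hidden))
w = w_qkv.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
w_q, w_k, w_v = mx.split(w, [n_kv_groups, n_kv_groups + 1], axis=1)
w_q = mx.flatten(w_q, 0, 2)  # (n_heads * head_dim, hidden)
w_k = mx.flatten(w_k, 0, 2)  # (n_kv_heads * head_dim, hidden)
w_v = mx.flatten(w_v, 0, 2)  # (n_kv_heads * head_dim, hidden)

# A fused gate_and_up_proj weight is two stacked halves; per the patch the
# first half becomes up_proj and the second becomes gate_proj.
w_gate_up = mx.random.normal((2 * intermediate, hidden))
w_up, w_gate = mx.split(w_gate_up, 2, axis=0)  # each (intermediate, hidden)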