chore: enable tie_word_embeddings config for qwen2 (#544)

Anchen authored on 2024-03-08 01:11:35 +11:00, committed by GitHub
parent b8e5eda4fd
commit 8a178f8716

@@ -21,6 +21,7 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = False

     def __post_init__(self):
         if self.num_key_value_heads is None:
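
The new dataclass field mirrors the tie_word_embeddings entry in a Hugging Face config.json and defaults to False, so untied Qwen2 configs behave as before. A minimal sketch of how the flag gets picked up when building ModelArgs from a config dict (the checkpoint path and the mlx_lm.models.qwen2 import path are assumptions for illustration):

import inspect
import json

from mlx_lm.models.qwen2 import ModelArgs  # import path assumed for this sketch

# Hypothetical checkpoint directory; tie_word_embeddings is read straight from
# config.json and falls back to the dataclass default (False) when absent.
with open("Qwen1.5-0.5B-Chat/config.json") as f:
    config = json.load(f)

args = ModelArgs(
    **{k: v for k, v in config.items() if k in inspect.signature(ModelArgs).parameters}
)
print(args.tie_word_embeddings)  # True for checkpoints that ship tied embeddings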
@@ -176,6 +177,7 @@ class Model(nn.Module):
         super().__init__()
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

     def __call__(
@@ -184,8 +186,12 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        if hasattr(self, "lm_head"):
+            return self.lm_head(out), cache
+
+        out = out @ self.model.embed_tokens.weight.T
+        return out, cache

     @staticmethod
     def sanitize(weights):
         # Remove unused precomputed rotary freqs
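
With the head now optional, __call__ dispatches on whether lm_head exists: untied models keep the separate projection, and tied models reuse the input embedding matrix transposed, which is what out @ self.model.embed_tokens.weight.T computes. A toy, self-contained sketch of the two branches in MLX (sizes and names are made up for illustration):

import mlx.core as mx
import mlx.nn as nn

# Toy sizes, purely illustrative.
vocab_size, hidden_size = 32, 8

embed = nn.Embedding(vocab_size, hidden_size)
tokens = mx.array([[1, 2, 3]])   # (batch, seq)
hidden = embed(tokens)           # (batch, seq, hidden_size)

# Untied: a dedicated output projection, as when tie_word_embeddings is False.
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
logits_untied = lm_head(hidden)

# Tied: reuse the embedding matrix transposed, matching the new
# out @ self.model.embed_tokens.weight.T branch above.
logits_tied = hidden @ embed.weight.T

print(logits_untied.shape, logits_tied.shape)  # both (1, 3, 32)

Skipping the lm_head attribute when the flag is set also means weight loading does not expect an lm_head.weight that tied checkpoints typically do not ship.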