chore: enable tie_word_embeddings config for qwen2 (#544)

This commit is contained in:
Anchen 2024-03-08 01:11:35 +11:00 committed by GitHub
parent b8e5eda4fd
commit 8a178f8716
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -21,6 +21,7 @@ class ModelArgs(BaseModelArgs):
rope_theta: float = 1000000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
def __post_init__(self):
if self.num_key_value_heads is None:
@ -176,7 +177,8 @@ class Model(nn.Module):
super().__init__()
self.model_type = args.model_type
self.model = Qwen2Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
@ -184,7 +186,11 @@ class Model(nn.Module):
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
if hasattr(self, "lm_head"):
return self.lm_head(out), cache
out = out @ self.model.embed_tokens.weight.T
return out, cache
@staticmethod
def sanitize(weights):