chore(mlx-lm): fix tie_word_embeddings for qwen2 (#566)

* chore: fix tie_word_embeddings for qwen2

* chore: default tie_word_embeddings to True

Author: Anchen
Date: 2024-03-13 15:34:32 +11:00
Committed by: GitHub
Parent: 39084e81c2
Commit: 3535408c99
5 changed files with 101 additions and 22 deletions
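For context, here is a minimal, framework-free sketch (not code from this repository) of what the sanitize() change below does: a checkpoint exported with tied word embeddings ships no "lm_head.weight" tensor, so loading must alias that key to the embedding matrix before the weights are applied. Only the key names mirror the diff; the rest is illustrative.

def sanitize(weights: dict, tie_word_embeddings: bool = True) -> dict:
    # Tied checkpoints omit the output projection; reuse the input embeddings.
    if tie_word_embeddings and "lm_head.weight" not in weights:
        weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
    # Drop precomputed rotary frequencies, which the MLX port recomputes itself.
    return {k: v for k, v in weights.items() if "rotary_emb.inv_freq" not in k}

# Hypothetical tied checkpoint: no "lm_head.weight" entry.
ckpt = {"model.embed_tokens.weight": [[0.1, 0.2], [0.3, 0.4]]}
print(sorted(sanitize(ckpt)))  # ['lm_head.weight', 'model.embed_tokens.weight']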

View File

@@ -179,8 +179,7 @@ class Model(nn.Module):
         out, cache = self.model(inputs, cache)
         return self.lm_head(out), cache
 
-    @staticmethod
-    def sanitize(weights):
+    def sanitize(self, weights):
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k

View File

@@ -21,7 +21,7 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-    tie_word_embeddings: bool = False
+    tie_word_embeddings: bool = True
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -168,10 +168,10 @@ class Qwen2Model(nn.Module):
 class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
+        self.args = args
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -179,14 +179,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        if hasattr(self, "lm_head"):
-            return self.lm_head(out), cache
-        out = out @ self.model.embed_tokens.weight.T
-        return out, cache
+        return self.lm_head(out), cache
 
-    @staticmethod
-    def sanitize(weights):
+    def sanitize(self, weights):
+        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k

View File

@@ -147,11 +147,10 @@ class Starcoder2Model(nn.Module):
 class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
+        self.args = args
         self.model_type = args.model_type
         self.model = Starcoder2Model(args)
-        # For 15B starcoder2 and fine-tuned models which don't tie word embeddings
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -159,11 +158,12 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        if not self.model.args.tie_word_embeddings:
-            return self.lm_head(out), cache
-        else:
-            out = out @ self.model.embed_tokens.weight.T
-            return out, cache
+        return self.lm_head(out), cache
 
+    def sanitize(self, weights):
+        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+        return weights
+
     @property
     def layers(self):
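
A quick sanity check (assuming mlx is installed; this snippet is not part of the commit) that the tied path is equivalent to the old matmul fallback: a bias-free lm_head whose weight equals the embedding matrix produces the same logits as out @ embed_tokens.weight.T.

import mlx.core as mx
import mlx.nn as nn

hidden, vocab = 8, 32
embed = nn.Embedding(vocab, hidden)
lm_head = nn.Linear(hidden, vocab, bias=False)
lm_head.weight = embed.weight  # what sanitize() effectively arranges via the checkpoint

out = mx.random.normal((2, hidden))
print(mx.allclose(lm_head(out), out @ embed.weight.T))  # expect True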