Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 12:49:50 +08:00
chore(mlx-lm): fix tie_word_embeddings for qwen2 (#566)
* chore: fix tie_word_embeddings for qwen2
* chore: default tie_word_embeddings to True
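When word embeddings are tied, the output projection reuses the input embedding matrix instead of keeping a separate lm_head weight, so the logits are just the hidden states multiplied by the transposed embedding table. The hunks below switch qwen2 (and starcoder2) from branching on this at inference time to always materializing lm_head and filling it from the embeddings when the checkpoint is tied. A minimal MLX sketch of why the two paths are equivalent (toy sizes and names, not code from this repo):

# Minimal sketch of weight tying in MLX (illustrative only, not repo code).
# nn.Embedding stores its weight as (vocab_size, hidden_size); nn.Linear
# computes x @ W.T with W of shape (out_dims, in_dims), so the same matrix
# can serve as both the input embedding and the output projection.
import mlx.core as mx
import mlx.nn as nn

vocab_size, hidden_size = 1000, 64  # toy sizes

embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
lm_head.weight = embed_tokens.weight  # "tie" the two projections

hidden = mx.random.normal((1, 8, hidden_size))
logits_tied = lm_head(hidden)                      # new path: always go through lm_head
logits_matmul = hidden @ embed_tokens.weight.T     # old path: explicit matmul with the embeddings

print(mx.allclose(logits_tied, logits_matmul))  # True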
@@ -179,8 +179,7 @@ class Model(nn.Module):
         out, cache = self.model(inputs, cache)
         return self.lm_head(out), cache
 
-    @staticmethod
-    def sanitize(weights):
+    def sanitize(self, weights):
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
@@ -21,7 +21,7 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-    tie_word_embeddings: bool = False
+    tie_word_embeddings: bool = True
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
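Changing the dataclass default means a config.json that omits tie_word_embeddings now resolves to True. A rough sketch of that behaviour, assuming the args are built by filtering the config dict down to the dataclass fields (the Args class and from_config helper here are hypothetical stand-ins, not the repo's BaseModelArgs):

# Hypothetical sketch: how a dataclass default fills in a missing config key.
from dataclasses import dataclass, fields

@dataclass
class Args:
    hidden_size: int
    vocab_size: int
    tie_word_embeddings: bool = True  # new default from this commit

def from_config(config: dict) -> Args:
    # keep only the keys that are dataclass fields, as BaseModelArgs-style loaders tend to do
    names = {f.name for f in fields(Args)}
    return Args(**{k: v for k, v in config.items() if k in names})

cfg = {"hidden_size": 1024, "vocab_size": 32000, "rope_theta": 1000000}
print(from_config(cfg).tie_word_embeddings)  # True, since the key is absent from the config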
@@ -168,10 +168,10 @@ class Qwen2Model(nn.Module):
 class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
+        self.args = args
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -179,14 +179,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        if hasattr(self, "lm_head"):
-            return self.lm_head(out), cache
+        return self.lm_head(out), cache
 
-        out = out @ self.model.embed_tokens.weight.T
-        return out, cache
-
-    @staticmethod
-    def sanitize(weights):
+    def sanitize(self, weights):
+        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
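The reworked sanitize does two things: for tied checkpoints that ship without an lm_head.weight entry it aliases the embedding table into that slot, and it still drops the precomputed rotary frequencies the MLX port does not use. A standalone sketch of that behaviour on a toy weights dict (keys mirror the diff; the arrays are placeholders):

# Standalone sketch of the sanitize behaviour shown above (toy data, not repo code).
import mlx.core as mx

def sanitize(weights, tie_word_embeddings=True):
    if tie_word_embeddings and "lm_head.weight" not in weights:
        # tied checkpoints ship no lm_head weight; reuse the embedding table
        weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
    # drop precomputed rotary frequencies that the MLX model does not use
    return {k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k}

weights = {
    "model.embed_tokens.weight": mx.zeros((1000, 64)),
    "model.layers.0.self_attn.rotary_emb.inv_freq": mx.zeros((32,)),
}
cleaned = sanitize(weights)
print(sorted(cleaned))  # ['lm_head.weight', 'model.embed_tokens.weight']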
@@ -147,11 +147,10 @@ class Starcoder2Model(nn.Module):
 class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
+        self.args = args
         self.model_type = args.model_type
         self.model = Starcoder2Model(args)
-        # For 15B starcoder2 and fine-tuned models which don't tie word embeddings
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -159,11 +158,12 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        if not self.model.args.tie_word_embeddings:
-            return self.lm_head(out), cache
-        else:
-            out = out @ self.model.embed_tokens.weight.T
-            return out, cache
+        return self.lm_head(out), cache
+
+    def sanitize(self, weights):
+        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+        return weights
 
     @property
     def layers(self):
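The net effect for both models is that a tied checkpoint with no lm_head.weight still loads: sanitize patches the key in before the weights are bound to the module tree. A hedged sketch of that ordering; the load_model scaffolding below is assumed for illustration, only Module.load_weights is a real MLX API:

# Illustrative loading order (the loader function itself is assumed, not quoted from mlx_lm).
def load_model(model, weights: dict):
    # let the model rewrite checkpoint keys first, e.g. copy the embeddings into lm_head.weight
    if hasattr(model, "sanitize"):
        weights = model.sanitize(weights)
    # mlx.nn.Module.load_weights accepts a list of (name, array) pairs
    model.load_weights(list(weights.items()))
    return model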