chore(mlx-lm): fix tie_word_embeddings for qwen2 (#566)

* chore: fix tie_word_embeddings for qwen2

* chore: default tie_word_embeddings to True
Anchen 2024-03-13 15:34:32 +11:00 committed by GitHub
parent 39084e81c2
commit 3535408c99
5 changed files with 101 additions and 22 deletions
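
In short: a checkpoint with tied word embeddings ships only model.embed_tokens.weight and no lm_head.weight, so the per-model sanitize hook now fills in the missing key by reusing the embedding matrix. A minimal standalone sketch of that pattern, with placeholder strings instead of real weight arrays (illustrative only, not the repo's loader):

# Sketch of the tie_word_embeddings handling introduced in the diffs below.
# Values are placeholder strings; in the real loader they are weight arrays.
def sanitize(weights: dict, tie_word_embeddings: bool = True) -> dict:
    # A tied checkpoint stores the token embedding but no output projection ...
    if tie_word_embeddings and "lm_head.weight" not in weights:
        # ... so the embedding matrix is reused as the lm_head weight.
        weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
    return weights

print(sanitize({"model.embed_tokens.weight": "E"}))
# {'model.embed_tokens.weight': 'E', 'lm_head.weight': 'E'}
print(sanitize({"model.embed_tokens.weight": "E", "lm_head.weight": "W"}))
# an existing lm_head.weight is left untouched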

View File

@@ -179,8 +179,7 @@ class Model(nn.Module):
         out, cache = self.model(inputs, cache)
         return self.lm_head(out), cache
 
-    @staticmethod
-    def sanitize(weights):
+    def sanitize(self, weights):
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k

View File

@@ -21,7 +21,7 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-    tie_word_embeddings: bool = False
+    tie_word_embeddings: bool = True
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -168,10 +168,10 @@ class Qwen2Model(nn.Module):
 class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
+        self.args = args
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -179,14 +179,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        if hasattr(self, "lm_head"):
-            return self.lm_head(out), cache
-        out = out @ self.model.embed_tokens.weight.T
-        return out, cache
+        return self.lm_head(out), cache
 
-    @staticmethod
-    def sanitize(weights):
+    def sanitize(self, weights):
+        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
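
Why the hasattr branch in __call__ could go: once sanitize aliases the embedding matrix into lm_head.weight, projecting through lm_head reproduces the old tied path out @ embed_tokens.weight.T. A small numerical check of that equivalence (shapes are made up; this is not code from the commit):

# Illustrative check that a Linear seeded with the embedding matrix matches
# the old tied-embedding projection; hidden/vocab sizes are arbitrary.
import mlx.core as mx
import mlx.nn as nn

hidden_size, vocab_size = 8, 16
embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
lm_head.weight = embed_tokens.weight  # what sanitize now does via the weights dict

out = mx.random.normal((2, 3, hidden_size))
print(mx.allclose(lm_head(out), out @ embed_tokens.weight.T))  # True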

View File

@@ -147,11 +147,10 @@ class Starcoder2Model(nn.Module):
 class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
+        self.args = args
         self.model_type = args.model_type
         self.model = Starcoder2Model(args)
-        # For 15B starcoder2 and fine-tuned models which don't tie word embeddings
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -159,11 +158,12 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        if not self.model.args.tie_word_embeddings:
-            return self.lm_head(out), cache
-        else:
-            out = out @ self.model.embed_tokens.weight.T
-            return out, cache
+        return self.lm_head(out), cache
+
+    def sanitize(self, weights):
+        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+        return weights
 
     @property
     def layers(self):

View File

@@ -319,12 +319,13 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
             weights.update(mx.load(wf))
 
     model_class, model_args_class = _get_classes(config=config)
-    if hasattr(model_class, "sanitize"):
-        weights = model_class.sanitize(weights)
 
     model_args = model_args_class.from_dict(config)
    model = model_class(model_args)
 
+    if hasattr(model, "sanitize"):
+        weights = model.sanitize(weights)
+
     if quantization is not None:
         # for legacy models that don't have lm_head quant due to non-32 dims
         if "lm_head.scales" not in weights.keys():

View File

@@ -129,6 +129,47 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+    def test_qwen2_tie_word_embeddings_without_lm_head_weight(self):
+        from mlx_lm.models import qwen2
+
+        args = qwen2.ModelArgs(
+            model_type="qwen2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+            tie_word_embeddings=True,
+        )
+        model = qwen2.Model(args)
+
+        weights = {"model.embed_tokens.weight": "some_value"}
+        sanitized_weights = model.sanitize(weights)
+        self.assertIn("lm_head.weight", sanitized_weights)
+        self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
+
+    def test_qwen2_tie_word_embeddings_with_lm_head_weight(self):
+        from mlx_lm.models import qwen2
+
+        weights = {
+            "model.embed_tokens.weight": "some_value",
+            "lm_head.weight": "existing_value",
+        }
+        args = qwen2.ModelArgs(
+            model_type="qwen2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+            tie_word_embeddings=True,
+        )
+        model = qwen2.Model(args)
+        sanitized_weights = model.sanitize(weights)
+        self.assertIn("lm_head.weight", sanitized_weights)
+        self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
+
     def test_qwen(self):
         from mlx_lm.models import qwen
@@ -194,6 +235,47 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+    def test_starcoder2_tie_word_embeddings_without_lm_head_weight(self):
+        from mlx_lm.models import starcoder2
+
+        args = starcoder2.ModelArgs(
+            model_type="starcoder2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            num_key_value_heads=4,
+            tie_word_embeddings=True,
+        )
+        model = starcoder2.Model(args)
+
+        weights = {"model.embed_tokens.weight": "some_value"}
+        sanitized_weights = model.sanitize(weights)
+        self.assertIn("lm_head.weight", sanitized_weights)
+        self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
+
+    def test_starcoder2_tie_word_embeddings_with_lm_head_weight(self):
+        from mlx_lm.models import starcoder2
+
+        args = starcoder2.ModelArgs(
+            model_type="starcoder2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            num_key_value_heads=4,
+            tie_word_embeddings=True,
+        )
+        model = starcoder2.Model(args)
+        weights = {
+            "model.embed_tokens.weight": "some_value",
+            "lm_head.weight": "existing_value",
+        }
+        sanitized_weights = model.sanitize(weights)
+        self.assertIn("lm_head.weight", sanitized_weights)
+        self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
+
 
 if __name__ == "__main__":
     unittest.main()
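
A quick way to sanity-check the end result on a real converted model, assuming an unquantized Qwen2 checkpoint with tie_word_embeddings enabled; the model path below is a placeholder, not something from this commit:

# Hypothetical spot check: after loading, lm_head should carry the embedding matrix.
import mlx.core as mx
from mlx_lm import load

model, tokenizer = load("path/to/unquantized-qwen2-mlx-model")  # placeholder path
print(mx.array_equal(model.lm_head.weight, model.model.embed_tokens.weight))  # expect True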