From 3535408c999a17dbd009f44e579b672b8e59c005 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Wed, 13 Mar 2024 15:34:32 +1100
Subject: [PATCH] chore(mlx-lm): fix tie_word_embeddings for qwen2 (#566)

* chore: fix tie_word_embeddings for qwen2

* chore: default tie_word_embeddings to True
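
Both qwen2 and starcoder2 now always construct an lm_head; tying is handled
at load time instead. sanitize() becomes an instance method (it needs
self.args.tie_word_embeddings), and when the config ties embeddings but the
checkpoint ships no lm_head.weight, the embedding matrix is reused for the
output projection. load_model() is updated accordingly to call sanitize()
on the constructed model rather than on the class. The pattern applied in
qwen2.py and starcoder2.py below:

    def sanitize(self, weights):
        # Tied checkpoints usually omit lm_head.weight; reuse the embedding
        # matrix for the output projection in that case.
        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
        return weights

With this change, Qwen2-architecture checkpoints that tie their word
embeddings (e.g. the smaller Qwen1.5 chat models) should load out of the
box; a minimal check, with the repo id given only as an example:

    from mlx_lm import load, generate

    model, tokenizer = load("Qwen/Qwen1.5-0.5B-Chat")
    print(generate(model, tokenizer, prompt="hello"))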
---
llms/mlx_lm/models/llama.py | 3 +-
llms/mlx_lm/models/qwen2.py | 17 +++----
llms/mlx_lm/models/starcoder2.py | 16 +++----
llms/mlx_lm/utils.py | 5 +-
llms/tests/test_models.py | 82 ++++++++++++++++++++++++++++++++
5 files changed, 101 insertions(+), 22 deletions(-)
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 817e5cc1..972f381e 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -179,8 +179,7 @@ class Model(nn.Module):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
- @staticmethod
- def sanitize(weights):
+ def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index 3773cbfb..1f196b0f 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -21,7 +21,7 @@ class ModelArgs(BaseModelArgs):
rope_theta: float = 1000000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
- tie_word_embeddings: bool = False
+ tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
@@ -168,10 +168,10 @@ class Qwen2Model(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
+ self.args = args
self.model_type = args.model_type
self.model = Qwen2Model(args)
- if not args.tie_word_embeddings:
- self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+ self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
@@ -179,14 +179,11 @@ class Model(nn.Module):
cache=None,
):
out, cache = self.model(inputs, cache)
- if hasattr(self, "lm_head"):
- return self.lm_head(out), cache
+ return self.lm_head(out), cache
- out = out @ self.model.embed_tokens.weight.T
- return out, cache
-
- @staticmethod
- def sanitize(weights):
+ def sanitize(self, weights):
+ if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
+ weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index aefcf88c..6b10f716 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -147,11 +147,10 @@ class Starcoder2Model(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
+ self.args = args
self.model_type = args.model_type
self.model = Starcoder2Model(args)
- # For 15B starcoder2 and fine-tuned models which don't tie word embeddings
- if not args.tie_word_embeddings:
- self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+ self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
@@ -159,11 +158,12 @@ class Model(nn.Module):
cache=None,
):
out, cache = self.model(inputs, cache)
- if not self.model.args.tie_word_embeddings:
- return self.lm_head(out), cache
- else:
- out = out @ self.model.embed_tokens.weight.T
- return out, cache
+ return self.lm_head(out), cache
+
+ def sanitize(self, weights):
+ if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
+ weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+ return weights
@property
def layers(self):
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 6adcf924..bfbe911b 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -319,12 +319,13 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
weights.update(mx.load(wf))
model_class, model_args_class = _get_classes(config=config)
- if hasattr(model_class, "sanitize"):
- weights = model_class.sanitize(weights)
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
+ if hasattr(model, "sanitize"):
+ weights = model.sanitize(weights)
+
if quantization is not None:
# for legacy models that don't have lm_head quant due to non-32 dims
if "lm_head.scales" not in weights.keys():
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index 7dcc82b2..7b99d065 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -129,6 +129,47 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
+ def test_qwen2_tie_word_embeddings_without_lm_head_weight(self):
+ from mlx_lm.models import qwen2
+
+ args = qwen2.ModelArgs(
+ model_type="qwen2",
+ hidden_size=1024,
+ num_hidden_layers=4,
+ intermediate_size=2048,
+ num_attention_heads=4,
+ rms_norm_eps=1e-5,
+ vocab_size=10_000,
+ tie_word_embeddings=True,
+ )
+ model = qwen2.Model(args)
+ weights = {"model.embed_tokens.weight": "some_value"}
+ sanitized_weights = model.sanitize(weights)
+ self.assertIn("lm_head.weight", sanitized_weights)
+ self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
+
+ def test_qwen2_tie_word_embeddings_with_lm_head_weight(self):
+ from mlx_lm.models import qwen2
+
+ weights = {
+ "model.embed_tokens.weight": "some_value",
+ "lm_head.weight": "existing_value",
+ }
+ args = qwen2.ModelArgs(
+ model_type="qwen2",
+ hidden_size=1024,
+ num_hidden_layers=4,
+ intermediate_size=2048,
+ num_attention_heads=4,
+ rms_norm_eps=1e-5,
+ vocab_size=10_000,
+ tie_word_embeddings=True,
+ )
+ model = qwen2.Model(args)
+ sanitized_weights = model.sanitize(weights)
+ self.assertIn("lm_head.weight", sanitized_weights)
+ self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
+
def test_qwen(self):
from mlx_lm.models import qwen
@@ -194,6 +235,47 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
+ def test_starcoder2_tie_word_embeddings_without_lm_head_weight(self):
+ from mlx_lm.models import starcoder2
+
+ args = starcoder2.ModelArgs(
+ model_type="starcoder2",
+ hidden_size=1024,
+ num_hidden_layers=4,
+ intermediate_size=2048,
+ num_attention_heads=4,
+ num_key_value_heads=4,
+ tie_word_embeddings=True,
+ )
+ model = starcoder2.Model(args)
+ weights = {"model.embed_tokens.weight": "some_value"}
+ sanitized_weights = model.sanitize(weights)
+ self.assertIn("lm_head.weight", sanitized_weights)
+ self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
+
+ def test_starcoder2_tie_word_embeddings_with_lm_head_weight(self):
+
+ from mlx_lm.models import starcoder2
+
+ args = starcoder2.ModelArgs(
+ model_type="starcoder2",
+ hidden_size=1024,
+ num_hidden_layers=4,
+ intermediate_size=2048,
+ num_attention_heads=4,
+ num_key_value_heads=4,
+ tie_word_embeddings=True,
+ )
+ model = starcoder2.Model(args)
+ weights = {
+ "model.embed_tokens.weight": "some_value",
+ "lm_head.weight": "existing_value",
+ }
+
+ sanitized_weights = model.sanitize(weights)
+ self.assertIn("lm_head.weight", sanitized_weights)
+ self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
+
if __name__ == "__main__":
unittest.main()