diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 817e5cc1..972f381e 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -179,8 +179,7 @@ class Model(nn.Module):
         out, cache = self.model(inputs, cache)
         return self.lm_head(out), cache
 
-    @staticmethod
-    def sanitize(weights):
+    def sanitize(self, weights):
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index 3773cbfb..1f196b0f 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -21,7 +21,7 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-    tie_word_embeddings: bool = False
+    tie_word_embeddings: bool = True
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -168,10 +168,10 @@ class Qwen2Model(nn.Module):
 class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
+        self.args = args
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -179,14 +179,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        if hasattr(self, "lm_head"):
-            return self.lm_head(out), cache
+        return self.lm_head(out), cache
 
-        out = out @ self.model.embed_tokens.weight.T
-        return out, cache
-
-    @staticmethod
-    def sanitize(weights):
+    def sanitize(self, weights):
+        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index aefcf88c..6b10f716 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -147,11 +147,10 @@ class Starcoder2Model(nn.Module):
 class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
+        self.args = args
         self.model_type = args.model_type
         self.model = Starcoder2Model(args)
-        # For 15B starcoder2 and fine-tuned models which don't tie word embeddings
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -159,11 +158,12 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        if not self.model.args.tie_word_embeddings:
-            return self.lm_head(out), cache
-        else:
-            out = out @ self.model.embed_tokens.weight.T
-            return out, cache
+        return self.lm_head(out), cache
+
+    def sanitize(self, weights):
+        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+        return weights
 
     @property
     def layers(self):
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 6adcf924..bfbe911b 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -319,12 +319,13 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
         weights.update(mx.load(wf))
 
     model_class, model_args_class = _get_classes(config=config)
-    if hasattr(model_class, "sanitize"):
-        weights = model_class.sanitize(weights)
 
     model_args = model_args_class.from_dict(config)
     model = model_class(model_args)
 
+    if hasattr(model, "sanitize"):
+        weights = model.sanitize(weights)
+
     if quantization is not None:
         # for legacy models that don't have lm_head quant due to non-32 dims
         if "lm_head.scales" not in weights.keys():
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index 7dcc82b2..7b99d065 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -129,6 +129,47 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+    def test_qwen2_tie_word_embeddings_without_lm_head_weight(self):
+        from mlx_lm.models import qwen2
+
+        args = qwen2.ModelArgs(
+            model_type="qwen2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+            tie_word_embeddings=True,
+        )
+        model = qwen2.Model(args)
+        weights = {"model.embed_tokens.weight": "some_value"}
+        sanitized_weights = model.sanitize(weights)
+        self.assertIn("lm_head.weight", sanitized_weights)
+        self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
+
+    def test_qwen2_tie_word_embeddings_with_lm_head_weight(self):
+        from mlx_lm.models import qwen2
+
+        weights = {
+            "model.embed_tokens.weight": "some_value",
+            "lm_head.weight": "existing_value",
+        }
+        args = qwen2.ModelArgs(
+            model_type="qwen2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+            tie_word_embeddings=True,
+        )
+        model = qwen2.Model(args)
+        sanitized_weights = model.sanitize(weights)
+        self.assertIn("lm_head.weight", sanitized_weights)
+        self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
+
     def test_qwen(self):
         from mlx_lm.models import qwen
 
@@ -194,6 +235,47 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+    def test_starcoder2_tie_word_embeddings_without_lm_head_weight(self):
+        from mlx_lm.models import starcoder2
+
+        args = starcoder2.ModelArgs(
+            model_type="starcoder2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            num_key_value_heads=4,
+            tie_word_embeddings=True,
+        )
+        model = starcoder2.Model(args)
+        weights = {"model.embed_tokens.weight": "some_value"}
+        sanitized_weights = model.sanitize(weights)
+        self.assertIn("lm_head.weight", sanitized_weights)
+        self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
+
+    def test_starcoder2_tie_word_embeddings_with_lm_head_weight(self):
+
+        from mlx_lm.models import starcoder2
+
+        args = starcoder2.ModelArgs(
+            model_type="starcoder2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            num_key_value_heads=4,
+            tie_word_embeddings=True,
+        )
+        model = starcoder2.Model(args)
+        weights = {
+            "model.embed_tokens.weight": "some_value",
+            "lm_head.weight": "existing_value",
+        }
+
+        sanitized_weights = model.sanitize(weights)
+        self.assertIn("lm_head.weight", sanitized_weights)
+        self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
+
 
 if __name__ == "__main__":
     unittest.main()
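
For reference (not part of the diff), a minimal sketch of the behavior the new instance-level sanitize hook provides: with tie_word_embeddings enabled, the embedding matrix is mirrored into lm_head.weight at load time whenever the checkpoint does not ship a separate lm_head tensor, and an existing lm_head tensor is left untouched. Placeholder string values stand in for real arrays, exactly as in the tests above:

    from mlx_lm.models import qwen2

    # Small config with tied embeddings (field values chosen arbitrarily, as in the tests).
    args = qwen2.ModelArgs(
        model_type="qwen2",
        hidden_size=1024,
        num_hidden_layers=4,
        intermediate_size=2048,
        num_attention_heads=4,
        rms_norm_eps=1e-5,
        vocab_size=10_000,
        tie_word_embeddings=True,
    )
    model = qwen2.Model(args)

    # Checkpoint without an explicit lm_head tensor: sanitize() fills it from the embeddings.
    weights = model.sanitize({"model.embed_tokens.weight": "some_value"})
    assert weights["lm_head.weight"] == "some_value"

    # Checkpoint that already carries lm_head.weight: sanitize() keeps the existing tensor.
    weights = model.sanitize(
        {"model.embed_tokens.weight": "some_value", "lm_head.weight": "existing_value"}
    )
    assert weights["lm_head.weight"] == "existing_value"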