Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 17:31:18 +08:00
chore(mlx-lm): fix tie_word_embeddings for qwen2 (#566)

* chore: fix tie_word_embeddings for qwen2
* chore: default tie_word_embeddings to True

parent 39084e81c2 · commit 3535408c99
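For context, "tying word embeddings" means the output projection reuses the input embedding matrix instead of carrying a separate `lm_head` weight. A minimal illustrative sketch (not code from this commit; sizes are placeholders) contrasting the two paths the diffs below reconcile:

```python
import mlx.core as mx
import mlx.nn as nn

vocab_size, hidden_size = 32, 8
embed_tokens = nn.Embedding(vocab_size, hidden_size)

tokens = mx.array([[1, 2, 3]])
hidden = embed_tokens(tokens)  # shape (1, 3, hidden_size)

# Untied: a dedicated output projection, as in `self.lm_head = nn.Linear(...)`.
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
logits_untied = lm_head(hidden)

# Tied: project through the transposed embedding matrix, as in the
# `out @ self.model.embed_tokens.weight.T` path this commit removes.
logits_tied = hidden @ embed_tokens.weight.T

print(logits_untied.shape, logits_tied.shape)  # (1, 3, 32) (1, 3, 32)
```

Rather than branching between these two paths at call time, the commit always materializes `lm_head` and, for tied checkpoints, fills its weight from the embedding matrix in `sanitize`.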
```diff
@@ -179,8 +179,7 @@ class Model(nn.Module):
         out, cache = self.model(inputs, cache)
         return self.lm_head(out), cache
 
-    @staticmethod
-    def sanitize(weights):
+    def sanitize(self, weights):
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
```
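The hunk above converts `sanitize` from a `@staticmethod` into an instance method, in line with the `load_model` change further down, which now calls `sanitize` on the constructed model instead of on the class.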
```diff
@@ -21,7 +21,7 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-    tie_word_embeddings: bool = False
+    tie_word_embeddings: bool = True
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
```
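With the flipped default, a qwen2 `ModelArgs` built without an explicit `tie_word_embeddings` now assumes tied weights; configs for untied variants must set it to `False` explicitly. A quick check (field values taken from the new tests below):

```python
from mlx_lm.models import qwen2

args = qwen2.ModelArgs(
    model_type="qwen2",
    hidden_size=1024,
    num_hidden_layers=4,
    intermediate_size=2048,
    num_attention_heads=4,
    rms_norm_eps=1e-5,
    vocab_size=10_000,
)
print(args.tie_word_embeddings)  # True under the new default
```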
```diff
@@ -168,10 +168,10 @@ class Qwen2Model(nn.Module):
 class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
+        self.args = args
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -179,14 +179,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        if hasattr(self, "lm_head"):
-            return self.lm_head(out), cache
-
-        out = out @ self.model.embed_tokens.weight.T
-        return out, cache
-
-    @staticmethod
-    def sanitize(weights):
+        return self.lm_head(out), cache
+
+    def sanitize(self, weights):
+        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
```
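The net effect for qwen2: `lm_head` always exists, and when a tied checkpoint ships no `lm_head.weight`, `sanitize` backfills it from `model.embed_tokens.weight` before the weights are applied. A sketch mirroring the new tests, reusing `args` from the snippet above (with `tie_word_embeddings=True`):

```python
model = qwen2.Model(args)

# Tied checkpoint: no lm_head.weight on disk, so sanitize copies it over.
weights = {"model.embed_tokens.weight": "embedding_matrix"}
assert model.sanitize(weights)["lm_head.weight"] == "embedding_matrix"

# If lm_head.weight is already present, it is left untouched.
weights = {"model.embed_tokens.weight": "emb", "lm_head.weight": "head"}
assert model.sanitize(weights)["lm_head.weight"] == "head"
```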
```diff
@@ -147,11 +147,10 @@ class Starcoder2Model(nn.Module):
 class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
+        self.args = args
         self.model_type = args.model_type
         self.model = Starcoder2Model(args)
-        # For 15B starcoder2 and fine-tuned models which don't tie word embeddings
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -159,11 +158,12 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        if not self.model.args.tie_word_embeddings:
-            return self.lm_head(out), cache
-        else:
-            out = out @ self.model.embed_tokens.weight.T
-            return out, cache
+        return self.lm_head(out), cache
+
+    def sanitize(self, weights):
+        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
+            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+        return weights
 
     @property
     def layers(self):
```
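starcoder2 gets the same treatment: the call-time branch on `tie_word_embeddings` is replaced by an unconditional `lm_head` plus the same `sanitize` backfill. The only difference is that starcoder2 has no precomputed rotary frequencies to strip, so its `sanitize` simply returns the (possibly augmented) weights dict.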
```diff
@@ -319,12 +319,13 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
         weights.update(mx.load(wf))
 
     model_class, model_args_class = _get_classes(config=config)
-    if hasattr(model_class, "sanitize"):
-        weights = model_class.sanitize(weights)
 
     model_args = model_args_class.from_dict(config)
     model = model_class(model_args)
 
+    if hasattr(model, "sanitize"):
+        weights = model.sanitize(weights)
+
     if quantization is not None:
         # for legacy models that don't have lm_head quant due to non-32 dims
         if "lm_head.scales" not in weights.keys():
```
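The reordering in `load_model` is what makes the per-model backfill possible: `sanitize` is now called on the constructed model rather than on the class, so it can read `self.args.tie_word_embeddings`, which the old `@staticmethod` variant had no access to.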
```diff
@@ -129,6 +129,47 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+    def test_qwen2_tie_word_embeddings_without_lm_head_weight(self):
+        from mlx_lm.models import qwen2
+
+        args = qwen2.ModelArgs(
+            model_type="qwen2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+            tie_word_embeddings=True,
+        )
+        model = qwen2.Model(args)
+        weights = {"model.embed_tokens.weight": "some_value"}
+        sanitized_weights = model.sanitize(weights)
+        self.assertIn("lm_head.weight", sanitized_weights)
+        self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
+
+    def test_qwen2_tie_word_embeddings_with_lm_head_weight(self):
+        from mlx_lm.models import qwen2
+
+        weights = {
+            "model.embed_tokens.weight": "some_value",
+            "lm_head.weight": "existing_value",
+        }
+        args = qwen2.ModelArgs(
+            model_type="qwen2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+            tie_word_embeddings=True,
+        )
+        model = qwen2.Model(args)
+        sanitized_weights = model.sanitize(weights)
+        self.assertIn("lm_head.weight", sanitized_weights)
+        self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
+
     def test_qwen(self):
         from mlx_lm.models import qwen
 
@@ -194,6 +235,47 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+    def test_starcoder2_tie_word_embeddings_without_lm_head_weight(self):
+        from mlx_lm.models import starcoder2
+
+        args = starcoder2.ModelArgs(
+            model_type="starcoder2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            num_key_value_heads=4,
+            tie_word_embeddings=True,
+        )
+        model = starcoder2.Model(args)
+        weights = {"model.embed_tokens.weight": "some_value"}
+        sanitized_weights = model.sanitize(weights)
+        self.assertIn("lm_head.weight", sanitized_weights)
+        self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
+
+    def test_starcoder2_tie_word_embeddings_with_lm_head_weight(self):
+
+        from mlx_lm.models import starcoder2
+
+        args = starcoder2.ModelArgs(
+            model_type="starcoder2",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            num_key_value_heads=4,
+            tie_word_embeddings=True,
+        )
+        model = starcoder2.Model(args)
+        weights = {
+            "model.embed_tokens.weight": "some_value",
+            "lm_head.weight": "existing_value",
+        }
+
+        sanitized_weights = model.sanitize(weights)
+        self.assertIn("lm_head.weight", sanitized_weights)
+        self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
+
 
 if __name__ == "__main__":
     unittest.main()
```
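The new cases can be run selectively with unittest's `-k` filter, e.g. `python -m unittest -k tie_word_embeddings test_models` (assuming the repo's dependencies are installed and the command is run from the directory containing `test_models.py`).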