chore(mlx-lm): fix tie_word_embeddings for qwen2 (#566)

* chore: fix tie_word_embeddings for qwen2

* chore: default tie_word_embeddings to True
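For context on what the two bullets change: with tie_word_embeddings enabled, the model reuses the input embedding matrix as the output projection instead of loading a separate lm_head weight. The minimal MLX sketch below illustrates that branch; ModelArgs, Backbone, and the field defaults are simplified stand-ins for the real qwen2 classes, not the actual file contents:

    from dataclasses import dataclass

    import mlx.core as mx
    import mlx.nn as nn


    @dataclass
    class ModelArgs:
        hidden_size: int = 32
        vocab_size: int = 100
        tie_word_embeddings: bool = True  # the new default per this commit


    class Backbone(nn.Module):
        # Stand-in for the real transformer stack; only the embedding
        # matters for this illustration.
        def __init__(self, args: ModelArgs):
            super().__init__()
            self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)

        def __call__(self, inputs: mx.array) -> mx.array:
            return self.embed_tokens(inputs)  # real model runs attention blocks here


    class Model(nn.Module):
        def __init__(self, args: ModelArgs):
            super().__init__()
            self.args = args
            self.model = Backbone(args)
            if not args.tie_word_embeddings:
                # Untied: a separate output projection is created and
                # loaded from lm_head.weight in the checkpoint.
                self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

        def __call__(self, inputs: mx.array) -> mx.array:
            h = self.model(inputs)
            if self.args.tie_word_embeddings:
                # Tied: reuse the embedding matrix as the output projection.
                return self.model.embed_tokens.as_linear(h)
            return self.lm_head(h)


    logits = Model(ModelArgs())(mx.array([[1, 2, 3]]))
    print(logits.shape)  # (1, 3, 100)

MLX's nn.Embedding.as_linear applies the embedding matrix (transposed) as a linear layer, which is what makes weight tying cheap: one matrix serves both the input lookup and the output logits.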
Author: Anchen
Date: 2024-03-13 15:34:32 +11:00
Committed by: GitHub
Parent: 39084e81c2
Commit: 3535408c99

5 changed files with 101 additions and 22 deletions


@@ -319,12 +319,13 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
         weights.update(mx.load(wf))
 
     model_class, model_args_class = _get_classes(config=config)
-    if hasattr(model_class, "sanitize"):
-        weights = model_class.sanitize(weights)
 
     model_args = model_args_class.from_dict(config)
     model = model_class(model_args)
 
+    if hasattr(model, "sanitize"):
+        weights = model.sanitize(weights)
+
     if quantization is not None:
         # for legacy models that don't have lm_head quant due to non-32 dims
         if "lm_head.scales" not in weights.keys():