import argparse
import json
import shutil
from pathlib import Path
from typing import Dict, Optional

import mlx.core as mx
import torch
from huggingface_hub import snapshot_download


def download(hf_repo: str) -> Path:
    """Download model from Hugging Face."""
    return Path(
        snapshot_download(
            repo_id=hf_repo,
            allow_patterns=["*.safetensors", "*.json", "*.bin", "*.txt"],
        )
    )


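# For example, download("facebook/esm2_t6_8M_UR50D") resolves to the local
# snapshot directory containing the weight, config, and tokenizer files
# matched by allow_patterns above.

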
def remap_key(key: str) -> Optional[str]:
    """Remap Hugging Face ESM key names to MLX format.

    Returns None for keys that should be dropped.
    """

    # Skip position embeddings and position_ids
    if "position_embeddings" in key or "position_ids" in key:
        return None

    # Map lm_head components explicitly, returning early so the generic
    # patterns below cannot rewrite them
    if key == "lm_head.decoder.weight":
        return "lm_head.weight"
    if key == "lm_head.decoder.bias":
        return "lm_head.bias"
    if key == "lm_head.dense.weight":
        return "lm_head.dense.weight"
    if key == "lm_head.dense.bias":
        return "lm_head.dense.bias"
    if key == "lm_head.layer_norm.weight":
        return "lm_head.layer_norm.weight"
    if key == "lm_head.layer_norm.bias":
        return "lm_head.layer_norm.bias"

    # Core remapping patterns
    key = key.replace("esm.embeddings.word_embeddings", "embed_tokens")
    key = key.replace("esm.encoder.emb_layer_norm_after", "emb_layer_norm_after")
    key = key.replace("esm.encoder.layer.", "layer_")
    key = key.replace("esm.contact_head", "contact_head")

    # Attention patterns (these must run before the FFN patterns so that
    # ".attention.output.dense" is consumed before ".output.dense")
    key = key.replace(".attention.self.", ".self_attn.")
    key = key.replace(".attention.output.dense", ".self_attn.out_proj")
    key = key.replace(".attention.LayerNorm", ".self_attn_layer_norm")
    key = key.replace(".query", ".q_proj")
    key = key.replace(".key", ".k_proj")
    key = key.replace(".value", ".v_proj")
    key = key.replace(".rotary_embeddings", ".rot_emb")

    # FFN patterns
    key = key.replace(".intermediate.dense", ".fc1")
    key = key.replace(".output.dense", ".fc2")
    key = key.replace(".LayerNorm", ".final_layer_norm")

    return key


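# Illustrative examples of the mapping above:
#   remap_key("esm.encoder.layer.0.attention.self.query.weight")
#     -> "layer_0.self_attn.q_proj.weight"
#   remap_key("esm.encoder.layer.0.output.dense.weight")
#     -> "layer_0.fc2.weight"
#   remap_key("esm.embeddings.position_embeddings.weight")
#     -> None (dropped)

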
def load_weights(model_path: Path) -> Dict:
    """Load weights from safetensors or PyTorch bin files."""

    # Check for safetensors file
    safetensors_path = model_path / "model.safetensors"
    if safetensors_path.exists():
        print("Loading from safetensors...")
        return mx.load(str(safetensors_path))

    # Check for single bin file
    single_bin_path = model_path / "pytorch_model.bin"
    if single_bin_path.exists():
        print("Loading from pytorch_model.bin...")
        state_dict = torch.load(str(single_bin_path), map_location="cpu")
        return {k: v.numpy() for k, v in state_dict.items()}

    # Check for sharded bin files
    index_file = model_path / "pytorch_model.bin.index.json"
    if index_file.exists():
        print("Loading from sharded bin files...")
        with open(index_file) as f:
            index = json.load(f)

        # Get unique shard files
        shard_files = set(index["weight_map"].values())

        # Load all shards
        state_dict = {}
        for shard_file in sorted(shard_files):
            print(f"  Loading shard: {shard_file}")
            shard_path = model_path / shard_file
            shard_dict = torch.load(str(shard_path), map_location="cpu")
            state_dict.update(shard_dict)

        return {k: v.numpy() for k, v in state_dict.items()}

    raise ValueError(f"No model weights found in {model_path}")


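# The sharded-checkpoint index consumed above follows the usual Hugging Face
# layout, roughly (illustrative):
#   {
#     "metadata": {"total_size": ...},
#     "weight_map": {"esm.encoder.layer.0...": "pytorch_model-00001-of-00002.bin", ...}
#   }

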
def convert(model_path: Path) -> Dict[str, mx.array]:
    """Convert ESM weights to MLX format."""

    # Load weights from any format
    weights = load_weights(model_path)

    # Convert keys and create MLX arrays
    mlx_weights = {}
    for key, value in weights.items():
        mlx_key = remap_key(key)
        if mlx_key is not None:
            mlx_weights[mlx_key] = (
                mx.array(value) if not isinstance(value, mx.array) else value
            )

    # If lm_head.weight is missing but embed_tokens.weight exists, set up weight
    # sharing (this covers smaller models that tie the output projection to the
    # token embeddings and so have no separate lm_head.decoder.weight)
    if "lm_head.weight" not in mlx_weights and "embed_tokens.weight" in mlx_weights:
        mlx_weights["lm_head.weight"] = mlx_weights["embed_tokens.weight"]

    return mlx_weights


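# Quick sanity check (illustrative):
#   weights = convert(download("facebook/esm2_t6_8M_UR50D"))
#   assert "embed_tokens.weight" in weights and "lm_head.weight" in weights

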
def main():
    parser = argparse.ArgumentParser(description="Convert ESM weights to MLX format")
    parser.add_argument(
        "--hf-path", default="facebook/esm2_t6_8M_UR50D", help="Hugging Face model path"
    )
    parser.add_argument("--mlx-path", default=None, help="Output path for MLX model")
    parser.add_argument(
        "--checkpoints-dir",
        default="checkpoints",
        help="Directory to save checkpoints (default: checkpoints)",
    )

    args = parser.parse_args()

    # Download model
    print(f"Downloading {args.hf_path}...")
    model_path = download(args.hf_path)

    # Set output path
    if args.mlx_path is None:
        model_name = args.hf_path.split("/")[-1]
        checkpoints_dir = Path(args.checkpoints_dir)
        checkpoints_dir.mkdir(parents=True, exist_ok=True)
        args.mlx_path = checkpoints_dir / f"mlx-{model_name}"
    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Convert weights
    print("Converting weights...")
    mlx_weights = convert(model_path)

    # Save weights
    print(f"Saving MLX weights to {mlx_path}...")
    mx.save_safetensors(str(mlx_path / "model.safetensors"), mlx_weights)

    # Copy config and tokenizer files
    print("Copying config...")
    shutil.copy(model_path / "config.json", mlx_path / "config.json")

    for file_name in ["special_tokens_map.json", "tokenizer.json", "vocab.txt"]:
        src_file = model_path / file_name
        if src_file.exists():
            shutil.copy(src_file, mlx_path / file_name)

    print(f"Conversion complete! MLX model saved to {mlx_path}")


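# Example invocation (assuming this file is saved as convert.py):
#   python convert.py --hf-path facebook/esm2_t6_8M_UR50D
# writes the MLX weights, config, and tokenizer files to
# checkpoints/mlx-esm2_t6_8M_UR50D by default; the result can then be loaded
# with, e.g., mx.load("checkpoints/mlx-esm2_t6_8M_UR50D/model.safetensors").

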
if __name__ == "__main__":
    main()