mlx-examples/esm/convert.py
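
"""Convert Hugging Face ESM checkpoints to MLX format.

A summary of what the script below does: it downloads a checkpoint from the
Hugging Face Hub, remaps the parameter names to the MLX module layout, and
writes an MLX safetensors file alongside the config and tokenizer files.
"""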

import argparse
import json
import shutil
from pathlib import Path
from typing import Dict, Optional

import mlx.core as mx
import torch
from huggingface_hub import snapshot_download


def download(hf_repo: str) -> Path:
    """Download model files from the Hugging Face Hub."""
    return Path(
        snapshot_download(
            repo_id=hf_repo,
            allow_patterns=["*.safetensors", "*.json", "*.bin", "*.txt"],
        )
    )


def remap_key(key: str) -> Optional[str]:
    """Remap Hugging Face ESM key names to MLX format.

    Returns None for keys that should be dropped.
    """
    # Skip learned position embeddings and position_ids buffers
    if "position_embeddings" in key or "position_ids" in key:
        return None
    # Map the lm_head decoder onto the MLX output projection
    if key == "lm_head.decoder.weight":
        return "lm_head.weight"
    if key == "lm_head.decoder.bias":
        return "lm_head.bias"
    # The remaining lm_head components (dense, layer_norm) keep their names
    if key.startswith("lm_head."):
        return key
    # Core remapping patterns
    key = key.replace("esm.embeddings.word_embeddings", "embed_tokens")
    key = key.replace("esm.encoder.emb_layer_norm_after", "emb_layer_norm_after")
    key = key.replace("esm.encoder.layer.", "layer_")
    key = key.replace("esm.contact_head", "contact_head")
    # Attention patterns
    key = key.replace(".attention.self.", ".self_attn.")
    key = key.replace(".attention.output.dense", ".self_attn.out_proj")
    key = key.replace(".attention.LayerNorm", ".self_attn_layer_norm")
    key = key.replace(".query", ".q_proj")
    key = key.replace(".key", ".k_proj")
    key = key.replace(".value", ".v_proj")
    key = key.replace(".rotary_embeddings", ".rot_emb")
    # FFN patterns
    key = key.replace(".intermediate.dense", ".fc1")
    key = key.replace(".output.dense", ".fc2")
    key = key.replace(".LayerNorm", ".final_layer_norm")
    return key
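
# For illustration, a few mappings remap_key produces (derived directly from
# the rules above; not exhaustive):
#   esm.embeddings.word_embeddings.weight           -> embed_tokens.weight
#   esm.encoder.layer.0.attention.self.query.weight -> layer_0.self_attn.q_proj.weight
#   esm.encoder.layer.0.intermediate.dense.bias     -> layer_0.fc1.bias
#   esm.embeddings.position_embeddings.weight       -> dropped (returns None)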


def load_weights(model_path: Path) -> Dict:
    """Load weights from safetensors or PyTorch bin files."""
    # Check for a single safetensors file
    safetensors_path = model_path / "model.safetensors"
    if safetensors_path.exists():
        print("Loading from safetensors...")
        return mx.load(str(safetensors_path))

    # Check for a single bin file
    single_bin_path = model_path / "pytorch_model.bin"
    if single_bin_path.exists():
        print("Loading from pytorch_model.bin...")
        state_dict = torch.load(str(single_bin_path), map_location="cpu")
        return {k: v.numpy() for k, v in state_dict.items()}

    # Check for sharded bin files
    index_file = model_path / "pytorch_model.bin.index.json"
    if index_file.exists():
        print("Loading from sharded bin files...")
        with open(index_file) as f:
            index = json.load(f)
        # The index maps each weight name to its shard; collect unique shards
        shard_files = set(index["weight_map"].values())
        # Load all shards into a single state dict
        state_dict = {}
        for shard_file in sorted(shard_files):
            print(f"  Loading shard: {shard_file}")
            shard_path = model_path / shard_file
            shard_dict = torch.load(str(shard_path), map_location="cpu")
            state_dict.update(shard_dict)
        return {k: v.numpy() for k, v in state_dict.items()}

    raise ValueError(f"No model weights found in {model_path}")
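
# For reference, pytorch_model.bin.index.json follows the standard Hugging
# Face sharded-checkpoint layout: a "weight_map" object from parameter name to
# shard file name, e.g. (illustrative entries):
#   {"weight_map": {"esm.embeddings.word_embeddings.weight":
#                   "pytorch_model-00001-of-00002.bin", ...}}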


def convert(model_path: Path) -> Dict[str, mx.array]:
    """Convert ESM weights to MLX format."""
    # Load weights from whichever format is present
    weights = load_weights(model_path)

    # Remap keys and create MLX arrays
    mlx_weights = {}
    for key, value in weights.items():
        mlx_key = remap_key(key)
        if mlx_key is not None:
            mlx_weights[mlx_key] = (
                mx.array(value) if not isinstance(value, mx.array) else value
            )

    # If lm_head.weight is missing but embed_tokens.weight exists, tie the
    # weights (smaller models lack a separate lm_head.decoder.weight)
    if "lm_head.weight" not in mlx_weights and "embed_tokens.weight" in mlx_weights:
        mlx_weights["lm_head.weight"] = mlx_weights["embed_tokens.weight"]

    return mlx_weights


def main():
    parser = argparse.ArgumentParser(description="Convert ESM weights to MLX format")
    parser.add_argument(
        "--hf-path",
        default="facebook/esm2_t6_8M_UR50D",
        help="Hugging Face model path",
    )
    parser.add_argument("--mlx-path", default=None, help="Output path for MLX model")
    parser.add_argument(
        "--checkpoints-dir",
        default="checkpoints",
        help="Directory to save checkpoints (default: checkpoints)",
    )
    args = parser.parse_args()

    # Download model
    print(f"Downloading {args.hf_path}...")
    model_path = download(args.hf_path)

    # Default the output path to <checkpoints-dir>/mlx-<model name>
    if args.mlx_path is None:
        model_name = args.hf_path.split("/")[-1]
        checkpoints_dir = Path(args.checkpoints_dir)
        checkpoints_dir.mkdir(parents=True, exist_ok=True)
        args.mlx_path = checkpoints_dir / f"mlx-{model_name}"
    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Convert weights
    print("Converting weights...")
    mlx_weights = convert(model_path)

    # Save weights
    print(f"Saving MLX weights to {mlx_path}...")
    mx.save_safetensors(str(mlx_path / "model.safetensors"), mlx_weights)

    # Copy config and tokenizer files
    print("Copying config...")
    shutil.copy(model_path / "config.json", mlx_path / "config.json")
    for file_name in ["special_tokens_map.json", "tokenizer.json", "vocab.txt"]:
        src_file = model_path / file_name
        if src_file.exists():
            shutil.copy(src_file, mlx_path / file_name)

    print(f"Conversion complete! MLX model saved to {mlx_path}")


if __name__ == "__main__":
    main()
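
# Example usage with the default checkpoint (any ESM-2 repo with the same
# layout should work):
#   python convert.py --hf-path facebook/esm2_t6_8M_UR50D
# This writes checkpoints/mlx-esm2_t6_8M_UR50D/model.safetensors plus the
# copied config/tokenizer files; the converted weights can be loaded back with
# mx.load("checkpoints/mlx-esm2_t6_8M_UR50D/model.safetensors").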