mlx-examples/esm/README.md
Vincent Amato 8e293bbc51 Add ESM
2025-08-16 15:59:51 -04:00

5.3 KiB
Raw Blame History

ESM-2

This repository provides an implementation of Meta's ESM-2 protein language model in MLX.1 ESM-2 is Metas second-generation Evolutionary Scale Model, a transformer-based protein language model trained on millions of diverse protein sequences with a masked language modeling objective.

Example contact prediction map

Example contact prediction map for a universal stress protein. In this case, ESM-2 650M achieves 86.4% precision at long-range contacts.

Setup

Install the requirements:

pip install -r requirements.txt

Usage

Below are the available ESM-2 models:

Model Parameters Layers
esm2_t6_8M_UR50D 8M 6
esm2_t12_35M_UR50D 35M 12
esm2_t30_150M_UR50D 150M 30
esm2_t33_650M_UR50D 650M 33
esm2_t36_3B_UR50D 3B 36
esm2_t48_15B_UR50D 15B 48

Convert a model to MLX format:

python convert.py --hf-path facebook/esm2_t33_650M_UR50D

This will save the converted model in a checkpoints directory.

Basic Inference

from esm import ESM2

# Load model and tokenizer
tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")

# Example protein sequence (human insulin)
sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"

# Tokenize and run inference
tokens = tokenizer.encode(sequence)
result = model(tokens)
logits = result["logits"]  # Shape: (batch, length, vocab_size)

Masked Language Modeling

# For a complete example, see main.py
python main.py --sequence "YOUR_SEQUENCE" --mask-position 50

Embeddings

# Get sequence-level representations
seq_repr = model.get_sequence_representations(tokens, layer=-1)  # Shape: (batch, embed_dim)

# Extract per-residue representations from specific layers
representations = model.extract_features(tokens, repr_layers=[20, 30, 33])
final_layer = representations[33]  # Shape: (batch, length, embed_dim)

Contact Prediction

# Predict residue-residue contacts
contacts = model.predict_contacts(tokens)  # Shape: (batch, length, length)

# Or compute contacts together with logits, representations, etc.
outputs = model(tokens, return_contacts=True)
contacts = outputs["contacts"]

Examples

Mutation Effect Prediction: notebooks/mutation_effect_prediction.ipynb

This notebook demonstrates how to use ESM-2 for zero-shot mutation effect prediction by scoring amino acid substitutions based on their likelihood under the model. We validate the approach using experimental fitness data from β-lactamase TEM, showing how ESM-2 captures functional constraints without requiring structural information.

Embeddings: notebooks/embeddings.ipynb

This notebook explores how ESM-2 generates meaningful protein embeddings that capture evolutionary and functional relationships between proteins. We analyze six diverse human proteins to demonstrate how the learned representations cluster proteins by function and reveal biological similarities.

Contact Prediction: notebooks/contact_prediction.ipynb

This notebook shows how to predict residue-residue contacts in protein structures using ESM-2's attention patterns. We evaluate contact prediction performance on three diverse proteins, demonstrating how the model captures both local and long-range structural relationships directly from sequence data.

Benchmarking

Benchmark MLX performance:

python benchmarks/benchmark_mx.py

Benchmark PyTorch MPS performance:

python benchmarks/benchmark_pt.py

Expected performance on M4 MacBook Pro (ESM-2 650M, batch_size = 5):

  • MLX: 299 ms per step, 16.71 sequences/sec
  • PyTorch MPS: 402 ms per step, 12.43 sequences/sec

Testing

Verify correctness against original implementation:

python test.py

This tests tokenizer and model outputs (logits, hidden states, and attentions) for equivalence with the original implementation.

Citations:

@article{rives2019biological,
  author={Rives, Alexander and Meier, Joshua and Sercu, Tom and Goyal, Siddharth and Lin, Zeming and Liu, Jason and Guo, Demi and Ott, Myle and Zitnick, C. Lawrence and Ma, Jerry and Fergus, Rob},
  title={Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences},
  year={2019},
  doi={10.1101/622803},
  url={https://www.biorxiv.org/content/10.1101/622803v4},
  journal={PNAS}
}

@article{Lin2023,
  author={Zeming Lin et al.},
  title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
  journal={Science},
  volume={379},
  pages={1123--1130},
  year={2023},
  doi={10.1126/science.ade2574},
  url={https://doi.org/10.1126/science.ade2574}
}

  1. Refer to the paper and code for more details. ↩︎