mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-21 20:46:50 +08:00
Add ESM
This commit is contained in:
parent
4b2a0df237
commit
8e293bbc51
156
esm/README.md
Normal file
@@ -0,0 +1,156 @@
# ESM-2

This repository provides an implementation of Meta's ESM-2 protein language model
in MLX.[^1] ESM-2 is Meta's second-generation Evolutionary Scale Model, a
transformer-based protein language model trained on millions of diverse protein
sequences with a masked language modeling objective.

![Example contact prediction map](assets/contact_prediction.png)

_Example contact prediction map for a universal stress protein. In this case, ESM-2 650M achieves 86.4% precision at long-range contacts._

## Setup

Install the requirements:

```bash
pip install -r requirements.txt
```

## Usage

Below are the available ESM-2 models:

| Model | Parameters | Layers |
|-------|------------|--------|
| [`esm2_t6_8M_UR50D`](https://huggingface.co/facebook/esm2_t6_8M_UR50D) | 8M | 6 |
| [`esm2_t12_35M_UR50D`](https://huggingface.co/facebook/esm2_t12_35M_UR50D) | 35M | 12 |
| [`esm2_t30_150M_UR50D`](https://huggingface.co/facebook/esm2_t30_150M_UR50D) | 150M | 30 |
| [`esm2_t33_650M_UR50D`](https://huggingface.co/facebook/esm2_t33_650M_UR50D) | 650M | 33 |
| [`esm2_t36_3B_UR50D`](https://huggingface.co/facebook/esm2_t36_3B_UR50D) | 3B | 36 |
| [`esm2_t48_15B_UR50D`](https://huggingface.co/facebook/esm2_t48_15B_UR50D) | 15B | 48 |

Convert a model to MLX format:

```bash
python convert.py --hf-path facebook/esm2_t33_650M_UR50D
```

This saves the converted model in the `checkpoints` directory (for example, `checkpoints/mlx-esm2_t33_650M_UR50D`).

### Basic Inference

```python
from esm import ESM2

# Load model and tokenizer
tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")

# Example protein sequence (human insulin)
sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"

# Tokenize and run inference
tokens = tokenizer.encode(sequence)
result = model(tokens)
logits = result["logits"]  # Shape: (batch, length, vocab_size)
```
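
As a quick sanity check, the most likely residue at each position can be read off the logits and decoded back into a sequence. A minimal sketch, continuing from the snippet above (greedy per-position argmax):

```python
import mlx.core as mx

# Most likely token at each position (positions include <cls> and <eos>)
pred_ids = mx.argmax(logits, axis=-1)[0]  # Shape: (length,)
print(tokenizer.decode(pred_ids, skip_special_tokens=True))
```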

### Masked Language Modeling

```bash
# For a complete example, see main.py
python main.py --sequence "YOUR_SEQUENCE" --mask-position 50
```
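
The same masked prediction can also be done programmatically. Below is a condensed sketch of what `main.py` does; the sequence and mask position here are placeholders:

```python
import mlx.core as mx

sequence = "MALWMRLLPLLALLALWGPDPAAA"  # placeholder sequence
pos = 10                               # residue index to mask

# Tokenize the residues before and after the masked position
before = tokenizer.encode(sequence[:pos], add_special_tokens=False)
after = tokenizer.encode(sequence[pos + 1 :], add_special_tokens=False)

# Build <cls> ... <mask> ... <eos> token IDs
tokens = mx.array(
    [
        [tokenizer.cls_id]
        + before.tolist()
        + [tokenizer.mask_id]
        + after.tolist()
        + [tokenizer.eos_id]
    ]
)

# Distribution over the vocabulary at the masked position
logits = model(tokens)["logits"]
probs = mx.softmax(logits[0, 1 + len(before), :])
best = int(mx.argmax(probs))
print(f"Predicted residue: {tokenizer.vocab[best]} ({float(probs[best]):.3f})")
```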

### Embeddings

```python
# Get sequence-level representations
seq_repr = model.get_sequence_representations(tokens, layer=-1)  # Shape: (batch, embed_dim)

# Extract per-residue representations from specific layers
representations = model.extract_features(tokens, repr_layers=[20, 30, 33])
final_layer = representations[33]  # Shape: (batch, length, embed_dim)
```
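
Sequence embeddings can then be compared directly, for example with cosine similarity. A small sketch, assuming `seq_a` and `seq_b` are protein strings:

```python
import mlx.core as mx

tokens = tokenizer.batch_encode([seq_a, seq_b])
emb = model.get_sequence_representations(tokens)  # Shape: (2, embed_dim)

cos_sim = mx.sum(emb[0] * emb[1]) / (mx.linalg.norm(emb[0]) * mx.linalg.norm(emb[1]))
print(float(cos_sim))
```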

### Contact Prediction

```python
# Predict residue-residue contacts
contacts = model.predict_contacts(tokens)  # Shape: (batch, length, length); BOS/EOS positions are excluded

# Or compute contacts together with logits, representations, etc.
outputs = model(tokens, return_contacts=True)
contacts = outputs["contacts"]
```
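
The predicted contact map can be visualized with matplotlib, assuming matplotlib and NumPy are available (the example notebooks use both); a minimal sketch:

```python
import matplotlib.pyplot as plt
import numpy as np

# Plot predicted contact probabilities for the first sequence in the batch
plt.imshow(np.array(contacts[0]), cmap="Blues")
plt.xlabel("Residue index")
plt.ylabel("Residue index")
plt.colorbar(label="Contact probability")
plt.savefig("contact_map.png", dpi=150)
```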

### Examples

**Mutation Effect Prediction**: [notebooks/mutation_effect_prediction.ipynb](notebooks/mutation_effect_prediction.ipynb)

This notebook demonstrates how to use ESM-2 for zero-shot mutation effect prediction by scoring amino acid substitutions based on their likelihood under the model. We validate the approach using experimental fitness data from β-lactamase TEM, showing how ESM-2 captures functional constraints without requiring structural information.

**Embeddings**: [notebooks/embeddings.ipynb](notebooks/embeddings.ipynb)

This notebook explores how ESM-2 generates meaningful protein embeddings that capture evolutionary and functional relationships between proteins. We analyze six diverse human proteins to demonstrate how the learned representations cluster proteins by function and reveal biological similarities.

**Contact Prediction**: [notebooks/contact_prediction.ipynb](notebooks/contact_prediction.ipynb)

This notebook shows how to predict residue-residue contacts in protein structures using ESM-2's attention patterns. We evaluate contact prediction performance on three diverse proteins, demonstrating how the model captures both local and long-range structural relationships directly from sequence data.

### Benchmarking

Benchmark MLX performance:

```bash
python benchmarks/benchmark_mx.py
```

Benchmark PyTorch MPS performance:

```bash
python benchmarks/benchmark_pt.py
```

Expected performance on an M4 MacBook Pro (ESM-2 650M, batch_size = 5):

- MLX: 299 ms per step, 16.71 sequences/sec
- PyTorch MPS: 402 ms per step, 12.43 sequences/sec

### Testing

Verify correctness against the original implementation:

```bash
python test.py
```

This tests the tokenizer and model outputs (logits, hidden states, and attentions) for equivalence with the original implementation.

### Citations

```bibtex
@article{rives2019biological,
  author={Rives, Alexander and Meier, Joshua and Sercu, Tom and Goyal, Siddharth and Lin, Zeming and Liu, Jason and Guo, Demi and Ott, Myle and Zitnick, C. Lawrence and Ma, Jerry and Fergus, Rob},
  title={Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences},
  year={2019},
  doi={10.1101/622803},
  url={https://www.biorxiv.org/content/10.1101/622803v4},
  journal={bioRxiv}
}
```

```bibtex
@article{Lin2023,
  author={Lin, Zeming and others},
  title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
  journal={Science},
  volume={379},
  pages={1123--1130},
  year={2023},
  doi={10.1126/science.ade2574},
  url={https://doi.org/10.1126/science.ade2574}
}
```

[^1]: Refer to the [paper](https://www.science.org/doi/10.1126/science.ade2574) and [code](https://github.com/facebookresearch/esm) for more details.
BIN
esm/assets/contact_prediction.png
Normal file
Binary file not shown. (Size: 34 KiB)
47
esm/benchmarks/benchmark_mx.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
# Add parent directory to Python path
|
||||
cur_path = Path(__file__).parents[1].resolve()
|
||||
sys.path.append(str(cur_path))
|
||||
|
||||
from esm import ESM2
|
||||
|
||||
# Example protein sequence (Green Fluorescent Protein)
|
||||
protein_sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
|
||||
|
||||
# Load pretrained ESM-2 model and its tokenizer from local checkpoint
|
||||
tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")
|
||||
|
||||
# Number of sequences to process in each forward pass
|
||||
batch_size = 5
|
||||
|
||||
# Number of timing iterations for performance measurement
|
||||
steps = 50
|
||||
|
||||
# Tokenize the protein sequence into integer IDs for the model
|
||||
# Replicate the same sequence 'batch_size' times to create a batch
|
||||
tokens = tokenizer.batch_encode([protein_sequence] * batch_size)
|
||||
|
||||
# Warm-up phase
|
||||
for _ in range(10):
|
||||
result = model(tokens)
|
||||
mx.eval(result["logits"]) # Force computation to complete
|
||||
|
||||
# Measure average inference time over 'steps' iterations
|
||||
tic = time.time()
|
||||
for _ in range(steps):
|
||||
result = model(tokens)
|
||||
mx.eval(result["logits"]) # Synchronize and ensure computation finishes
|
||||
toc = time.time()
|
||||
|
||||
# Compute metrics: average time per step (ms) and throughput (sequences/sec)
|
||||
ms_per_step = 1000 * (toc - tic) / steps
|
||||
throughput = batch_size * 1000 / ms_per_step
|
||||
|
||||
# Display results
|
||||
print(f"Time (ms) per step: {ms_per_step:.3f}")
|
||||
print(f"Throughput: {throughput:.2f} sequences/sec")
|
52
esm/benchmarks/benchmark_pt.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, EsmForMaskedLM
|
||||
|
||||
# Example protein sequence (Green Fluorescent Protein)
|
||||
protein_sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
|
||||
|
||||
# Hugging Face model identifier for ESM-2 (33 layers, 650M params, UR50D training set)
|
||||
model_name = "facebook/esm2_t33_650M_UR50D"
|
||||
|
||||
# Load tokenizer and model; move model to Apple Metal Performance Shaders (MPS) device
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = EsmForMaskedLM.from_pretrained(model_name).to("mps")
|
||||
|
||||
# Number of sequences per forward pass
|
||||
batch_size = 5
|
||||
|
||||
# Number of timing iterations
|
||||
steps = 50
|
||||
|
||||
# Tokenize input sequence and replicate for the batch
|
||||
# Replicate the same sequence 'batch_size' times to create a batch
|
||||
inputs = tokenizer(
|
||||
[protein_sequence] * batch_size,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=1024,
|
||||
)
|
||||
input_ids = inputs["input_ids"].to("mps")
|
||||
attention_mask = inputs["attention_mask"].to("mps")
|
||||
|
||||
# Warm-up phase
|
||||
for _ in range(10):
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
torch.mps.synchronize() # Ensure all queued ops on MPS are complete before next step
|
||||
|
||||
# Timed inference loop
|
||||
tic = time.time()
|
||||
for _ in range(steps):
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
torch.mps.synchronize() # Wait for computation to finish before timing next iteration
|
||||
toc = time.time()
|
||||
|
||||
# Compute performance metrics
|
||||
ms_per_step = 1000 * (toc - tic) / steps
|
||||
throughput = batch_size * 1000 / ms_per_step
|
||||
|
||||
# Report results
|
||||
print(f"Time (ms) per step: {ms_per_step:.3f}")
|
||||
print(f"Throughput: {throughput:.2f} sequences/sec")
|
177
esm/convert.py
Normal file
177
esm/convert.py
Normal file
@ -0,0 +1,177 @@
|
||||
import argparse
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
def download(hf_repo: str) -> Path:
|
||||
"""Download model from Hugging Face."""
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=hf_repo,
|
||||
allow_patterns=["*.safetensors", "*.json", "*.bin", "*.txt"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def remap_key(key: str) -> str:
|
||||
"""Remap HuggingFace ESM key names to MLX format."""
|
||||
|
||||
# Skip position embeddings and position_ids
|
||||
if "position_embeddings" in key or "position_ids" in key:
|
||||
return None
|
||||
|
||||
# Map lm_head components properly
|
||||
if key == "lm_head.decoder.weight":
|
||||
return "lm_head.weight"
|
||||
if key == "lm_head.decoder.bias":
|
||||
return "lm_head.bias"
|
||||
if key == "lm_head.dense.weight":
|
||||
return "lm_head.dense.weight"
|
||||
if key == "lm_head.dense.bias":
|
||||
return "lm_head.dense.bias"
|
||||
if key == "lm_head.layer_norm.weight":
|
||||
return "lm_head.layer_norm.weight"
|
||||
if key == "lm_head.layer_norm.bias":
|
||||
return "lm_head.layer_norm.bias"
|
||||
|
||||
# Core remapping patterns
|
||||
key = key.replace("esm.embeddings.word_embeddings", "embed_tokens")
|
||||
key = key.replace("esm.encoder.emb_layer_norm_after", "emb_layer_norm_after")
|
||||
key = key.replace("esm.encoder.layer.", "layer_")
|
||||
key = key.replace("esm.contact_head", "contact_head")
|
||||
key = key.replace("lm_head", "lm_head")
|
||||
|
||||
# Attention patterns
|
||||
key = key.replace(".attention.self.", ".self_attn.")
|
||||
key = key.replace(".attention.output.dense", ".self_attn.out_proj")
|
||||
key = key.replace(".attention.LayerNorm", ".self_attn_layer_norm")
|
||||
key = key.replace(".query", ".q_proj")
|
||||
key = key.replace(".key", ".k_proj")
|
||||
key = key.replace(".value", ".v_proj")
|
||||
key = key.replace(".rotary_embeddings", ".rot_emb")
|
||||
|
||||
# FFN patterns
|
||||
key = key.replace(".intermediate.dense", ".fc1")
|
||||
key = key.replace(".output.dense", ".fc2")
|
||||
key = key.replace(".LayerNorm", ".final_layer_norm")
|
||||
|
||||
return key
|
||||
|
||||
|
||||
def load_weights(model_path: Path) -> Dict:
|
||||
"""Load weights from safetensors or PyTorch bin files."""
|
||||
|
||||
# Check for safetensors file
|
||||
safetensors_path = model_path / "model.safetensors"
|
||||
if safetensors_path.exists():
|
||||
print("Loading from safetensors...")
|
||||
return mx.load(str(safetensors_path))
|
||||
|
||||
# Check for single bin file
|
||||
single_bin_path = model_path / "pytorch_model.bin"
|
||||
if single_bin_path.exists():
|
||||
print("Loading from pytorch_model.bin...")
|
||||
state_dict = torch.load(str(single_bin_path), map_location="cpu")
|
||||
return {k: v.numpy() for k, v in state_dict.items()}
|
||||
|
||||
# Check for sharded bin files
|
||||
index_file = model_path / "pytorch_model.bin.index.json"
|
||||
if index_file.exists():
|
||||
print("Loading from sharded bin files...")
|
||||
with open(index_file) as f:
|
||||
index = json.load(f)
|
||||
|
||||
# Get unique shard files
|
||||
shard_files = set(index["weight_map"].values())
|
||||
|
||||
# Load all shards
|
||||
state_dict = {}
|
||||
for shard_file in sorted(shard_files):
|
||||
print(f" Loading shard: {shard_file}")
|
||||
shard_path = model_path / shard_file
|
||||
shard_dict = torch.load(str(shard_path), map_location="cpu")
|
||||
state_dict.update(shard_dict)
|
||||
|
||||
return {k: v.numpy() for k, v in state_dict.items()}
|
||||
|
||||
raise ValueError(f"No model weights found in {model_path}")
|
||||
|
||||
|
||||
def convert(model_path: Path) -> Dict[str, mx.array]:
|
||||
"""Convert ESM weights to MLX format."""
|
||||
|
||||
# Load weights from any format
|
||||
weights = load_weights(model_path)
|
||||
|
||||
# Convert keys and create MLX arrays
|
||||
mlx_weights = {}
|
||||
for key, value in weights.items():
|
||||
mlx_key = remap_key(key)
|
||||
if mlx_key is not None:
|
||||
mlx_weights[mlx_key] = (
|
||||
mx.array(value) if not isinstance(value, mx.array) else value
|
||||
)
|
||||
|
||||
# If lm_head.weight is missing but embed_tokens.weight exists, set up weight sharing
|
||||
# (This is for smaller models that don't have a separate lm_head.decoder.weight)
|
||||
if "lm_head.weight" not in mlx_weights and "embed_tokens.weight" in mlx_weights:
|
||||
mlx_weights["lm_head.weight"] = mlx_weights["embed_tokens.weight"]
|
||||
|
||||
return mlx_weights
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Convert ESM weights to MLX format")
|
||||
parser.add_argument(
|
||||
"--hf-path", default="facebook/esm2_t6_8M_UR50D", help="Hugging Face model path"
|
||||
)
|
||||
parser.add_argument("--mlx-path", default=None, help="Output path for MLX model")
|
||||
parser.add_argument(
|
||||
"--checkpoints-dir",
|
||||
default="checkpoints",
|
||||
help="Directory to save checkpoints (default: checkpoints)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Download model
|
||||
print(f"Downloading {args.hf_path}...")
|
||||
model_path = download(args.hf_path)
|
||||
|
||||
# Set output path
|
||||
if args.mlx_path is None:
|
||||
model_name = args.hf_path.split("/")[-1]
|
||||
checkpoints_dir = Path(args.checkpoints_dir)
|
||||
checkpoints_dir.mkdir(parents=True, exist_ok=True)
|
||||
args.mlx_path = checkpoints_dir / f"mlx-{model_name}"
|
||||
mlx_path = Path(args.mlx_path)
|
||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Convert weights
|
||||
print("Converting weights...")
|
||||
mlx_weights = convert(model_path)
|
||||
|
||||
# Save weights
|
||||
print(f"Saving MLX weights to {mlx_path}...")
|
||||
mx.save_safetensors(str(mlx_path / "model.safetensors"), mlx_weights)
|
||||
|
||||
# Copy config and other files
|
||||
print("Copying config...")
|
||||
shutil.copy(model_path / "config.json", mlx_path / "config.json")
|
||||
|
||||
for file_name in ["special_tokens_map.json", "tokenizer.json", "vocab.txt"]:
|
||||
src_file = model_path / file_name
|
||||
if src_file.exists():
|
||||
shutil.copy(src_file, mlx_path / file_name)
|
||||
|
||||
print(f"Conversion complete! MLX model saved to {mlx_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
19
esm/esm/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
ESM-2 protein language model implementation in MLX
|
||||
"""
|
||||
|
||||
from .attention import MultiheadAttention
|
||||
from .model import ESM2
|
||||
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
|
||||
from .rotary_embedding import RotaryEmbedding
|
||||
from .tokenizer import ProteinTokenizer
|
||||
|
||||
__all__ = [
|
||||
"ESM2",
|
||||
"ProteinTokenizer",
|
||||
"ContactPredictionHead",
|
||||
"RobertaLMHead",
|
||||
"TransformerLayer",
|
||||
"MultiheadAttention",
|
||||
"RotaryEmbedding",
|
||||
]
|
153
esm/esm/attention.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .rotary_embedding import RotaryEmbedding
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
"""
|
||||
Multi-head attention layer with rotary position embeddings, as used in ESM-2.
|
||||
|
||||
This module implements both self-attention (when `key` and `value` are not
|
||||
provided) and cross-attention. It projects input sequences into queries,
|
||||
keys, and values, applies rotary position embeddings to encode relative
|
||||
position information, computes scaled dot-product attention over multiple
|
||||
heads in parallel, and returns a combined output projection.
|
||||
|
||||
Args:
|
||||
embed_dim (int): Total embedding dimension of the model input and output.
|
||||
num_heads (int): Number of parallel attention heads. Must divide `embed_dim`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
# Linear projections for queries, keys, and values (with bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
||||
|
||||
# Linear projection for output (with bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
||||
|
||||
# ESM-2 uses rotary embeddings
|
||||
self.rot_emb = RotaryEmbedding(dim=self.head_dim)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
query,
|
||||
key: Optional[mx.array] = None,
|
||||
value: Optional[mx.array] = None,
|
||||
key_padding_mask: Optional[mx.array] = None,
|
||||
attn_mask: Optional[mx.array] = None,
|
||||
need_head_weights: bool = False,
|
||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||
"""
|
||||
Multi-head attention forward pass.
|
||||
|
||||
Args:
|
||||
query: Tensor of shape (tgt_len, batch, embed_dim).
|
||||
key: Optional tensor of shape (src_len, batch, embed_dim). Defaults to `query`.
|
||||
value: Optional tensor of shape (src_len, batch, embed_dim). Defaults to `query`.
|
||||
key_padding_mask: Optional mask of shape (batch, src_len) to ignore padded positions.
|
||||
attn_mask: Optional mask for attention (e.g., causal mask).
|
||||
need_head_weights: If True, return attention weights for each head separately.
|
||||
|
||||
Returns:
|
||||
attn_output: Tensor of shape (tgt_len, batch, embed_dim).
|
||||
attn_weights_out: Attention weights of shape
|
||||
(num_heads, batch, tgt_len, src_len) if per-head,
|
||||
or (batch, tgt_len, src_len) if averaged.
|
||||
"""
|
||||
|
||||
tgt_len, bsz, embed_dim = query.shape
|
||||
assert embed_dim == self.embed_dim
|
||||
|
||||
# For self-attention, use query as key and value if not provided
|
||||
if key is None:
|
||||
key = query
|
||||
if value is None:
|
||||
value = query
|
||||
|
||||
# Project queries, keys, values
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(value)
|
||||
|
||||
q = q * self.scaling
|
||||
|
||||
# Reshape for multi-head attention
|
||||
q = q.reshape(tgt_len, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
|
||||
k = k.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
|
||||
v = v.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
|
||||
|
||||
src_len = k.shape[1]
|
||||
|
||||
# Apply rotary embeddings if present
|
||||
if self.rot_emb:
|
||||
q, k = self.rot_emb(q, k)
|
||||
|
||||
# Compute attention weights
|
||||
attn_weights = q @ k.swapaxes(-2, -1)
|
||||
|
||||
assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
|
||||
|
||||
# Apply attention mask
|
||||
if attn_mask is not None:
|
||||
attn_mask = mx.expand_dims(attn_mask, 0)
|
||||
attn_weights = attn_weights + attn_mask
|
||||
|
||||
# Apply key padding mask
|
||||
if key_padding_mask is not None:
|
||||
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len)
|
||||
# Convert key_padding_mask to boolean and expand dimensions
|
||||
# key_padding_mask: [bsz, src_len] -> [bsz, 1, 1, src_len]
|
||||
mask = mx.expand_dims(
|
||||
mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2
|
||||
)
|
||||
# Apply mask: set attention to -inf where mask is True (padded positions)
|
||||
attn_weights = mx.where(mask, -mx.inf, attn_weights)
|
||||
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
# Apply softmax
|
||||
attn_weights_float = mx.softmax(attn_weights.astype(mx.float32), axis=-1)
|
||||
attn_probs = attn_weights_float
|
||||
|
||||
# Compute attention output
|
||||
attn = attn_probs @ v
|
||||
assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||
|
||||
# Reshape output
|
||||
attn = attn.swapaxes(0, 1).reshape(tgt_len, bsz, embed_dim)
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
# Return attention weights if requested
|
||||
attn_weights_out: Optional[mx.array] = None
|
||||
if need_head_weights:
|
||||
# Return attention weights for each head separately
|
||||
attn_weights_out = (
|
||||
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len)
|
||||
.astype(attn.dtype)
|
||||
.swapaxes(0, 1)
|
||||
)
|
||||
else:
|
||||
# Return averaged attention weights
|
||||
attn_weights_out = mx.mean(
|
||||
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len),
|
||||
axis=1,
|
||||
).astype(attn.dtype)
|
||||
|
||||
return attn, attn_weights_out
|
340
esm/esm/model.py
Normal file
@@ -0,0 +1,340 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
|
||||
from .tokenizer import ProteinTokenizer
|
||||
|
||||
|
||||
class ESM2(nn.Module):
|
||||
"""
|
||||
ESM-2 protein language model in MLX.
|
||||
|
||||
Args:
|
||||
num_layers (int): Number of transformer layers.
|
||||
embed_dim (int): Embedding dimension.
|
||||
attention_heads (int): Number of attention heads.
|
||||
tokenizer (Optional[ProteinTokenizer]): Tokenizer to use (created if None).
|
||||
token_dropout (bool): Apply token-dropout masking behavior.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 33,
|
||||
embed_dim: int = 1280,
|
||||
attention_heads: int = 20,
|
||||
tokenizer: Optional[ProteinTokenizer] = None,
|
||||
token_dropout: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
self.embed_dim = embed_dim
|
||||
self.attention_heads = attention_heads
|
||||
|
||||
# Initialize tokenizer
|
||||
if tokenizer is None:
|
||||
tokenizer = ProteinTokenizer()
|
||||
self.tokenizer = tokenizer
|
||||
self.vocab_size = len(tokenizer)
|
||||
|
||||
# Special token IDs / config
|
||||
self.padding_idx = tokenizer.pad_id
|
||||
self.mask_idx = tokenizer.mask_id
|
||||
self.cls_idx = tokenizer.cls_id
|
||||
self.eos_idx = tokenizer.eos_id
|
||||
self.prepend_bos = tokenizer.prepend_bos
|
||||
self.append_eos = tokenizer.append_eos
|
||||
self.token_dropout = token_dropout
|
||||
|
||||
self._init_submodules()
|
||||
|
||||
def _init_submodules(self) -> None:
|
||||
"""Initialize embeddings, transformer stack, and output heads."""
|
||||
self.embed_scale = 1
|
||||
|
||||
# Token embeddings
|
||||
self.embed_tokens = nn.Embedding(self.vocab_size, self.embed_dim)
|
||||
|
||||
# Transformer layers (register each layer so MLX tracks parameters)
|
||||
self._layer_indices = list(range(self.num_layers))
|
||||
for i in self._layer_indices:
|
||||
layer = TransformerLayer(
|
||||
self.embed_dim,
|
||||
4 * self.embed_dim, # FFN dimension = 4×embed_dim
|
||||
self.attention_heads,
|
||||
)
|
||||
setattr(self, f"layer_{i}", layer)
|
||||
|
||||
# Contact prediction head (uses all layers × heads attentions)
|
||||
self.contact_head = ContactPredictionHead(
|
||||
self.num_layers * self.attention_heads,
|
||||
self.prepend_bos,
|
||||
self.append_eos,
|
||||
eos_idx=self.eos_idx,
|
||||
)
|
||||
|
||||
# Final norm + LM head (tied weights)
|
||||
self.emb_layer_norm_after = nn.LayerNorm(self.embed_dim)
|
||||
self.lm_head = RobertaLMHead(
|
||||
embed_dim=self.embed_dim,
|
||||
output_dim=self.vocab_size,
|
||||
weight=self.embed_tokens.weight,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
tokens: mx.array,
|
||||
repr_layers: List[int] = [],
|
||||
need_head_weights: bool = False,
|
||||
return_contacts: bool = False,
|
||||
) -> Dict[str, mx.array]:
|
||||
"""
|
||||
Forward pass through ESM-2.
|
||||
|
||||
Args:
|
||||
tokens: Tensor of token IDs with shape (B, T).
|
||||
repr_layers: Layers to return hidden states from (0..num_layers).
|
||||
need_head_weights: If True, return attention weights.
|
||||
return_contacts: If True, compute residue-residue contact probabilities.
|
||||
|
||||
Returns:
|
||||
dict with:
|
||||
logits: (B, T, V)
|
||||
representations: {layer_idx: (B, T, E)}
|
||||
attentions: (B, L, H, T, T) if requested
|
||||
contacts: (B, T', T') if requested
|
||||
"""
|
||||
if return_contacts:
|
||||
need_head_weights = True
|
||||
|
||||
# Ensure tokens is 2D (B, T)
|
||||
if tokens.ndim == 1:
|
||||
tokens = mx.expand_dims(tokens, axis=0)
|
||||
assert tokens.ndim == 2
|
||||
|
||||
# Padding mask (B, T)
|
||||
padding_mask = mx.equal(tokens, self.padding_idx)
|
||||
|
||||
# Embed tokens (B, T, E)
|
||||
x = self.embed_scale * self.embed_tokens(tokens)
|
||||
|
||||
# Token dropout: zero masked tokens + rescale based on observed mask ratio
|
||||
if self.token_dropout:
|
||||
# x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
|
||||
mask_positions = mx.equal(tokens, self.mask_idx)
|
||||
x = mx.where(mask_positions[:, :, None], 0.0, x)
|
||||
|
||||
# x: B x T x C
|
||||
mask_ratio_train = 0.15 * 0.8
|
||||
src_lengths = mx.sum(~padding_mask, axis=-1) # Shape: (B,)
|
||||
mask_ratio_observed = mx.sum(mask_positions, axis=-1).astype(x.dtype) / src_lengths # Shape: (B,)
|
||||
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
|
||||
|
||||
# Zero out padding positions
|
||||
if padding_mask.any():
|
||||
x = x * (1 - padding_mask[:, :, None].astype(x.dtype))
|
||||
|
||||
# Track requested representations
|
||||
repr_layers = set(repr_layers)
|
||||
hidden_representations: Dict[int, mx.array] = {}
|
||||
if 0 in repr_layers:
|
||||
hidden_representations[0] = x
|
||||
|
||||
if need_head_weights:
|
||||
attn_weights: List[mx.array] = []
|
||||
|
||||
# (B, T, E) -> (T, B, E) for transformer layers
|
||||
x = mx.swapaxes(x, 0, 1)
|
||||
|
||||
# If no padding anywhere, skip the mask
|
||||
if not padding_mask.any():
|
||||
padding_mask = None
|
||||
|
||||
# Transformer stack
|
||||
for layer_idx in self._layer_indices:
|
||||
layer = getattr(self, f"layer_{layer_idx}")
|
||||
x, attn = layer(
|
||||
x,
|
||||
self_attn_padding_mask=padding_mask,
|
||||
need_head_weights=need_head_weights,
|
||||
)
|
||||
|
||||
# Save hidden representation if requested (store back as (B, T, E))
|
||||
if (layer_idx + 1) in repr_layers:
|
||||
hidden_representations[layer_idx + 1] = mx.swapaxes(x, 0, 1)
|
||||
|
||||
# Save per-layer attentions if requested (H, B, T, T) -> (B, H, T, T)
|
||||
if need_head_weights:
|
||||
attn_weights.append(mx.swapaxes(attn, 0, 1))
|
||||
|
||||
# Final layer norm, back to (B, T, E)
|
||||
x = self.emb_layer_norm_after(x)
|
||||
x = mx.swapaxes(x, 0, 1)
|
||||
|
||||
# Save final hidden if requested
|
||||
if (layer_idx + 1) in repr_layers:
|
||||
hidden_representations[layer_idx + 1] = x
|
||||
|
||||
# Language modeling logits
|
||||
x = self.lm_head(x)
|
||||
|
||||
# Build result dict
|
||||
result: Dict[str, mx.array] = {
|
||||
"logits": x,
|
||||
"representations": hidden_representations,
|
||||
}
|
||||
|
||||
# Collect attentions and optional contacts
|
||||
if need_head_weights:
|
||||
# Stack layers -> (B, L, H, T, T)
|
||||
attentions = mx.stack(attn_weights, axis=1)
|
||||
|
||||
# Mask out padded positions if present
|
||||
if padding_mask is not None:
|
||||
attention_mask = 1 - padding_mask.astype(attentions.dtype)
|
||||
attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(
|
||||
attention_mask, 2
|
||||
)
|
||||
attentions = attentions * attention_mask[:, None, None, :, :]
|
||||
|
||||
result["attentions"] = attentions
|
||||
|
||||
# Compute contacts if requested
|
||||
if return_contacts:
|
||||
contacts = self.contact_head(tokens, attentions)
|
||||
result["contacts"] = contacts
|
||||
|
||||
return result
|
||||
|
||||
def predict_contacts(self, tokens: mx.array) -> mx.array:
|
||||
"""
|
||||
Predict residue-residue contacts.
|
||||
|
||||
Args:
|
||||
tokens: Tensor of shape (B, T).
|
||||
|
||||
Returns:
|
||||
mx.array: Contact probabilities of shape (B, T', T').
|
||||
"""
|
||||
return self(tokens, return_contacts=True)["contacts"]
|
||||
|
||||
def extract_features(
|
||||
self,
|
||||
tokens: mx.array,
|
||||
repr_layers: Optional[List[int]] = None,
|
||||
return_all_hiddens: bool = False,
|
||||
) -> Dict[int, mx.array]:
|
||||
"""
|
||||
Extract hidden representations from selected layers.
|
||||
|
||||
Args:
|
||||
tokens: Tensor of shape (B, T).
|
||||
repr_layers: Layer indices to return (default: last layer).
|
||||
return_all_hiddens: If True, return all layers (0..num_layers).
|
||||
|
||||
Returns:
|
||||
dict: {layer_idx: (B, T, E)} for requested layers.
|
||||
"""
|
||||
if return_all_hiddens:
|
||||
repr_layers = list(range(self.num_layers + 1))
|
||||
elif repr_layers is None:
|
||||
repr_layers = [self.num_layers]
|
||||
|
||||
result = self(tokens, repr_layers=repr_layers)
|
||||
return result["representations"]
|
||||
|
||||
def get_sequence_representations(
|
||||
self,
|
||||
tokens: mx.array,
|
||||
layer: int = -1,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Average token representations into a per-sequence embedding.
|
||||
|
||||
Args:
|
||||
tokens: Tensor of shape (B, T).
|
||||
layer: Layer index to use (-1 or num_layers for last).
|
||||
|
||||
Returns:
|
||||
mx.array: Sequence embeddings of shape (B, E).
|
||||
"""
|
||||
if layer == -1:
|
||||
layer = self.num_layers
|
||||
|
||||
representations = self.extract_features(tokens, repr_layers=[layer])
|
||||
repr = representations[layer]
|
||||
|
||||
# Mask: non-padding and not CLS; optionally not EOS
|
||||
mask = mx.logical_and(
|
||||
mx.not_equal(tokens, self.padding_idx),
|
||||
mx.not_equal(tokens, self.cls_idx),
|
||||
)
|
||||
if self.append_eos:
|
||||
mask = mx.logical_and(mask, mx.not_equal(tokens, self.eos_idx))
|
||||
|
||||
# Mean over valid positions
|
||||
mask = mask[:, :, None].astype(repr.dtype)
|
||||
masked_repr = repr * mask
|
||||
seq_lens = mx.sum(mask, axis=1, keepdims=True)
|
||||
seq_repr = mx.sum(masked_repr, axis=1) / mx.maximum(seq_lens[:, :, 0], 1.0)
|
||||
|
||||
return seq_repr
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str) -> Tuple[ProteinTokenizer, "ESM2"]:
|
||||
"""
|
||||
Load model weights and config from a directory.
|
||||
|
||||
Expects:
|
||||
- config.json
|
||||
- model.safetensors
|
||||
- vocab.txt (optional, will use default if not found)
|
||||
- special_tokens_map.json (optional, will use default if not found)
|
||||
|
||||
Args:
|
||||
model_path: Path to directory with weights and config.
|
||||
|
||||
Returns:
|
||||
(tokenizer, model): Initialized tokenizer and ESM2 model.
|
||||
"""
|
||||
model_dir = Path(model_path)
|
||||
config_path = model_dir / "config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Check for vocab and special tokens files
|
||||
vocab_path = model_dir / "vocab.txt"
|
||||
special_tokens_path = model_dir / "special_tokens_map.json"
|
||||
|
||||
if vocab_path.exists() and special_tokens_path.exists():
|
||||
tokenizer = ProteinTokenizer(
|
||||
vocab_file=str(vocab_path),
|
||||
special_tokens_map_file=str(special_tokens_path),
|
||||
)
|
||||
else:
|
||||
tokenizer = ProteinTokenizer()
|
||||
|
||||
model = cls(
|
||||
num_layers=config["num_hidden_layers"],
|
||||
embed_dim=config["hidden_size"],
|
||||
attention_heads=config["num_attention_heads"],
|
||||
tokenizer=tokenizer,
|
||||
token_dropout=config["token_dropout"],
|
||||
)
|
||||
|
||||
# Load safetensors as nested dict and update model params
|
||||
weights_path = model_dir / "model.safetensors"
|
||||
flat_weights = mx.load(str(weights_path))
|
||||
nested_weights: Dict[str, dict] = {}
|
||||
for key, value in flat_weights.items():
|
||||
parts = key.split(".")
|
||||
cur = nested_weights
|
||||
for p in parts[:-1]:
|
||||
cur = cur.setdefault(p, {})
|
||||
cur[parts[-1]] = value
|
||||
|
||||
model.update(nested_weights)
|
||||
return tokenizer, model
|
212
esm/esm/modules.py
Normal file
@@ -0,0 +1,212 @@
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import MultiheadAttention
|
||||
|
||||
|
||||
def symmetrize(x: mx.array) -> mx.array:
|
||||
"""
|
||||
Make a tensor symmetric over its last two dimensions.
|
||||
|
||||
Args:
|
||||
x: Tensor with shape (..., L, L).
|
||||
|
||||
Returns:
|
||||
mx.array: Symmetrized tensor of shape (..., L, L).
|
||||
"""
|
||||
# Add tensor to its transpose over the last two dims
|
||||
return x + mx.swapaxes(x, -1, -2)
|
||||
|
||||
|
||||
def apc(x: mx.array) -> mx.array:
|
||||
"""
|
||||
Apply Average Product Correction (APC) to remove background co-variation.
|
||||
|
||||
Args:
|
||||
x: Tensor with shape (..., L, L).
|
||||
|
||||
Returns:
|
||||
mx.array: APC-corrected tensor of shape (..., L, L).
|
||||
"""
|
||||
# Compute row, column, and total sums
|
||||
a1 = mx.sum(x, axis=-1, keepdims=True)
|
||||
a2 = mx.sum(x, axis=-2, keepdims=True)
|
||||
a12 = mx.sum(x, axis=(-1, -2), keepdims=True)
|
||||
|
||||
# Expected co-variation under independence
|
||||
expected = (a1 * a2) / a12
|
||||
return x - expected
|
||||
|
||||
|
||||
class TransformerLayer(nn.Module):
|
||||
"""
|
||||
Transformer layer used in ESM-2.
|
||||
|
||||
Args:
|
||||
embed_dim (int): Model embedding dimension.
|
||||
ffn_embed_dim (int): Hidden dimension of the feed-forward network.
|
||||
attention_heads (int): Number of attention heads.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
ffn_embed_dim: int,
|
||||
attention_heads: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.ffn_embed_dim = ffn_embed_dim
|
||||
self.attention_heads = attention_heads
|
||||
self._init_submodules()
|
||||
|
||||
def _init_submodules(self) -> None:
|
||||
"""Initialize attention, norms, and feed-forward submodules."""
|
||||
self.self_attn = MultiheadAttention(self.embed_dim, self.attention_heads)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
|
||||
self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
self_attn_mask: Optional[mx.array] = None,
|
||||
self_attn_padding_mask: Optional[mx.array] = None,
|
||||
need_head_weights: bool = False,
|
||||
):
|
||||
"""
|
||||
Forward pass for the Transformer layer.
|
||||
|
||||
Args:
|
||||
x: Tensor of shape (seq_len, batch, embed_dim).
|
||||
self_attn_mask: Optional attention mask.
|
||||
self_attn_padding_mask: Optional padding mask of shape (batch, seq_len).
|
||||
need_head_weights: If True, return per-head attention weights.
|
||||
|
||||
Returns:
|
||||
x: Tensor of shape (seq_len, batch, embed_dim).
|
||||
attn: Attention weights of shape
|
||||
(num_heads, batch, tgt_len, src_len) if per-head,
|
||||
or (batch, tgt_len, src_len) if averaged.
|
||||
"""
|
||||
# Self-attention block
|
||||
residual = x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x, attn = self.self_attn(
|
||||
query=x,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
attn_mask=self_attn_mask,
|
||||
need_head_weights=need_head_weights,
|
||||
)
|
||||
x = residual + x
|
||||
|
||||
# Feed-forward block
|
||||
residual = x
|
||||
x = self.final_layer_norm(x)
|
||||
x = nn.gelu(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
x = residual + x
|
||||
|
||||
return x, attn
|
||||
|
||||
|
||||
class RobertaLMHead(nn.Module):
|
||||
"""
|
||||
Masked Language Modeling (MLM) head with tied weights.
|
||||
|
||||
Args:
|
||||
embed_dim (int): Embedding dimension of the backbone.
|
||||
output_dim (int): Vocabulary size.
|
||||
weight (mx.array): Embedding weight matrix for tied projection.
|
||||
"""
|
||||
|
||||
def __init__(self, embed_dim: int, output_dim: int, weight: mx.array):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(embed_dim, embed_dim)
|
||||
self.layer_norm = nn.LayerNorm(embed_dim)
|
||||
self.weight = weight
|
||||
self.bias = mx.zeros(output_dim)
|
||||
|
||||
def __call__(self, features: mx.array) -> mx.array:
|
||||
"""
|
||||
Forward pass for the MLM head.
|
||||
|
||||
Args:
|
||||
features: Tensor of shape (seq_len, batch, embed_dim).
|
||||
|
||||
Returns:
|
||||
mx.array: Logits of shape (seq_len, batch, output_dim).
|
||||
"""
|
||||
# Transform features before projection to vocab
|
||||
x = self.dense(features)
|
||||
x = nn.gelu(x)
|
||||
x = self.layer_norm(x)
|
||||
return mx.matmul(x, self.weight.T) + self.bias
|
||||
|
||||
|
||||
class ContactPredictionHead(nn.Module):
|
||||
"""
|
||||
Predict residue-residue contact probabilities from attention maps.
|
||||
|
||||
Args:
|
||||
in_features (int): Number of attention channels (layers × heads).
|
||||
prepend_bos (bool): If True, drop BOS/CLS token attentions.
|
||||
append_eos (bool): If True, drop EOS token attentions.
|
||||
bias (bool): Whether the regression layer uses a bias term.
|
||||
eos_idx (Optional[int]): Token ID for EOS; required if append_eos=True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
prepend_bos: bool,
|
||||
append_eos: bool,
|
||||
bias: bool = True,
|
||||
eos_idx: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.prepend_bos = prepend_bos
|
||||
self.append_eos = append_eos
|
||||
if append_eos and eos_idx is None:
|
||||
raise ValueError("append_eos=True but eos_idx was not provided.")
|
||||
self.eos_idx = eos_idx
|
||||
self.regression = nn.Linear(in_features, 1, bias=bias)
|
||||
|
||||
def __call__(self, tokens: mx.array, attentions: mx.array) -> mx.array:
|
||||
"""
|
||||
Forward pass for contact prediction.
|
||||
|
||||
Args:
|
||||
tokens: Tensor of shape (B, T).
|
||||
attentions: Tensor of shape (B, L, H, T, T).
|
||||
|
||||
Returns:
|
||||
mx.array: Contact probabilities of shape (B, T', T'),
|
||||
where T' = T - [prepend_bos] - [append_eos].
|
||||
"""
|
||||
# Remove EOS attentions if requested
|
||||
if self.append_eos:
|
||||
eos_mask = mx.not_equal(tokens, self.eos_idx).astype(attentions.dtype)
|
||||
eos_mask = eos_mask[:, None, :] * eos_mask[:, :, None]
|
||||
attentions = attentions * eos_mask[:, None, None, :, :]
|
||||
attentions = attentions[..., :-1, :-1]
|
||||
|
||||
# Remove BOS attentions if requested
|
||||
if self.prepend_bos:
|
||||
attentions = attentions[..., 1:, 1:]
|
||||
|
||||
# Merge (layers × heads) into channel dimension
|
||||
batch_size, layers, heads, seqlen, _ = attentions.shape
|
||||
attentions = attentions.reshape(batch_size, layers * heads, seqlen, seqlen)
|
||||
|
||||
# Symmetrize and apply APC to enhance contact signal
|
||||
attentions = apc(symmetrize(attentions))
|
||||
|
||||
# Apply logistic regression over channels
|
||||
attentions = mx.transpose(attentions, axes=[0, 2, 3, 1])
|
||||
logits = self.regression(attentions)
|
||||
return nn.sigmoid(mx.squeeze(logits, axis=3))
|
114
esm/esm/rotary_embedding.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def rotate_half(x: mx.array) -> mx.array:
|
||||
"""
|
||||
Rotate last dimension by splitting into two halves and swapping.
|
||||
|
||||
Args:
|
||||
x: Tensor with even-sized last dimension.
|
||||
|
||||
Returns:
|
||||
mx.array: Tensor of same shape as `x` with halves rotated.
|
||||
"""
|
||||
# Split into two equal halves along the last dimension
|
||||
x1, x2 = mx.split(x, 2, axis=-1)
|
||||
# Swap halves and negate the second half
|
||||
return mx.concatenate((-x2, x1), axis=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(x: mx.array, cos: mx.array, sin: mx.array) -> mx.array:
|
||||
"""
|
||||
Apply rotary position embeddings to a tensor.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (..., seq_len, dim).
|
||||
cos: Cosine embedding table of shape (1, seq_len, dim).
|
||||
sin: Sine embedding table of shape (1, seq_len, dim).
|
||||
|
||||
Returns:
|
||||
mx.array: Tensor with rotary position embeddings applied.
|
||||
"""
|
||||
# Trim cos/sin to match the sequence length of x
|
||||
cos = cos[:, : x.shape[-2], :]
|
||||
sin = sin[:, : x.shape[-2], :]
|
||||
|
||||
# Elementwise rotation: (x * cos) + (rotate_half(x) * sin)
|
||||
return (x * cos) + (rotate_half(x) * sin)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
"""
|
||||
Rotary position embedding (RoPE) module.
|
||||
|
||||
Args:
|
||||
dim (int): Head dimension size (must be even).
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, *_, **__):
|
||||
super().__init__()
|
||||
# Precompute inverse frequency for each pair of dimensions
|
||||
self.inv_freq = 1.0 / (10000 ** (mx.arange(0, dim, 2).astype(mx.float32) / dim))
|
||||
|
||||
# Cache for cosine/sine tables to avoid recomputation
|
||||
self._seq_len_cached = None
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
|
||||
def _update_cos_sin_tables(
|
||||
self, x: mx.array, seq_dimension: int = 1
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""
|
||||
Compute and cache cos/sin tables for the given sequence length.
|
||||
|
||||
Args:
|
||||
x: Reference tensor for sequence length.
|
||||
seq_dimension: Axis containing the sequence length.
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
cos: Cosine table of shape (1, seq_len, dim).
|
||||
sin: Sine table of shape (1, seq_len, dim).
|
||||
"""
|
||||
seq_len = x.shape[seq_dimension]
|
||||
# Only update cache if sequence length has changed
|
||||
if seq_len != self._seq_len_cached:
|
||||
self._seq_len_cached = seq_len
|
||||
# Time steps: shape (seq_len,)
|
||||
t = mx.arange(seq_len).astype(self.inv_freq.dtype)
|
||||
# Outer product between time and inverse frequency
|
||||
freqs = mx.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Duplicate frequencies for cos/sin dimensions
|
||||
emb = mx.concatenate((freqs, freqs), axis=-1)
|
||||
|
||||
self._cos_cached = mx.cos(emb)[None, :, :]
|
||||
self._sin_cached = mx.sin(emb)[None, :, :]
|
||||
|
||||
return self._cos_cached, self._sin_cached
|
||||
|
||||
def __call__(self, q: mx.array, k: mx.array) -> Tuple[mx.array, mx.array]:
|
||||
"""
|
||||
Apply rotary position embeddings to queries and keys.
|
||||
|
||||
Args:
|
||||
q: Query tensor of shape (..., seq_len, dim).
|
||||
k: Key tensor of shape (..., seq_len, dim).
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
q_rot: Query tensor with RoPE applied.
|
||||
k_rot: Key tensor with RoPE applied.
|
||||
"""
|
||||
# Get (and cache) cos/sin tables based on key sequence length
|
||||
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
|
||||
k, seq_dimension=-2
|
||||
)
|
||||
|
||||
# Apply rotary embeddings to both q and k
|
||||
return (
|
||||
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
|
||||
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
|
||||
)
|
241
esm/esm/tokenizer.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
# Canonical amino-acid tokens (IUPAC standard + uncommon variants)
|
||||
PROTEIN_TOKENS = [
|
||||
"L",
|
||||
"A",
|
||||
"G",
|
||||
"V",
|
||||
"S",
|
||||
"E",
|
||||
"R",
|
||||
"T",
|
||||
"I",
|
||||
"D",
|
||||
"P",
|
||||
"K",
|
||||
"Q",
|
||||
"N",
|
||||
"F",
|
||||
"Y",
|
||||
"M",
|
||||
"H",
|
||||
"W",
|
||||
"C",
|
||||
"X",
|
||||
"B",
|
||||
"U",
|
||||
"Z",
|
||||
"O",
|
||||
".",
|
||||
"-",
|
||||
]
|
||||
|
||||
ArrayLike = Union[List[int], mx.array]
|
||||
|
||||
|
||||
class ProteinTokenizer:
|
||||
"""
|
||||
Protein sequence tokenizer compatible with ESM-2.
|
||||
|
||||
This class converts protein sequences into token IDs and back, following
|
||||
the vocabulary, special tokens, and formatting rules used by ESM-2.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file: Optional[str] = None,
|
||||
special_tokens_map_file: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the ProteinTokenizer.
|
||||
|
||||
Args:
|
||||
vocab_file: Optional path to a file containing the vocabulary,
|
||||
one token per line.
|
||||
special_tokens_map_file: Optional path to a JSON file defining
|
||||
special token names and values.
|
||||
|
||||
If both files are provided, they override the default vocabulary and
|
||||
special token mappings. Otherwise, defaults are loaded.
|
||||
"""
|
||||
|
||||
# Load vocabulary from files if given, otherwise use built-in defaults
|
||||
if vocab_file and special_tokens_map_file:
|
||||
self._load_from_files(vocab_file, special_tokens_map_file)
|
||||
else:
|
||||
self._load_default_vocab()
|
||||
|
||||
# Create token ↔ ID mappings
|
||||
self.token_to_id = {tok: i for i, tok in enumerate(self.vocab)}
|
||||
self.id_to_token = {i: tok for i, tok in enumerate(self.vocab)}
|
||||
|
||||
# Cache special token IDs
|
||||
self.cls_id = self.token_to_id["<cls>"]
|
||||
self.pad_id = self.token_to_id["<pad>"]
|
||||
self.eos_id = self.token_to_id["<eos>"]
|
||||
self.unk_id = self.token_to_id["<unk>"]
|
||||
self.mask_id = self.token_to_id["<mask>"]
|
||||
|
||||
# Behavior flags for ESM-2-style BOS/EOS
|
||||
self.prepend_bos = True
|
||||
self.append_eos = True
|
||||
|
||||
def _load_from_files(self, vocab_file: str, special_tokens_map_file: str) -> None:
|
||||
"""Load vocabulary and special tokens from the provided files."""
|
||||
# Vocabulary file: one token per line
|
||||
vocab_path = Path(vocab_file)
|
||||
with open(vocab_path, "r", encoding="utf-8") as f:
|
||||
self.vocab = [line.strip() for line in f if line.strip()]
|
||||
|
||||
# Special tokens mapping file (JSON)
|
||||
special_tokens_path = Path(special_tokens_map_file)
|
||||
with open(special_tokens_path, "r", encoding="utf-8") as f:
|
||||
self.special_tokens_map = json.load(f)
|
||||
|
||||
def _load_default_vocab(self) -> None:
|
||||
"""Load the built-in ESM vocabulary and special token mapping."""
|
||||
# ESM convention: prepend special tokens, then amino acids, then <mask>
|
||||
prepend_toks = ["<cls>", "<pad>", "<eos>", "<unk>"]
|
||||
append_toks = ["<mask>"]
|
||||
|
||||
self.vocab = prepend_toks + PROTEIN_TOKENS
|
||||
|
||||
# Pad vocab size to multiple of 8 (original implementation detail)
|
||||
while len(self.vocab) % 8 != 0:
|
||||
self.vocab.append(f"<null_{len(self.vocab) - len(prepend_toks)}>")
|
||||
|
||||
self.vocab.extend(append_toks)
|
||||
|
||||
# Default special tokens map
|
||||
self.special_tokens_map = {
|
||||
"cls_token": "<cls>",
|
||||
"pad_token": "<pad>",
|
||||
"eos_token": "<eos>",
|
||||
"unk_token": "<unk>",
|
||||
"mask_token": "<mask>",
|
||||
}
|
||||
|
||||
def encode(
|
||||
self,
|
||||
sequence: str,
|
||||
*,
|
||||
add_special_tokens: bool = True,
|
||||
return_batch_dim: bool = False,
|
||||
dtype=mx.int32,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Convert a protein sequence into token IDs.
|
||||
|
||||
Args:
|
||||
sequence: Protein sequence (case-insensitive).
|
||||
add_special_tokens: If True, add <cls> at the start and <eos> at the end.
|
||||
return_batch_dim: If True, output shape will be (1, L) instead of (L,).
|
||||
dtype: MLX dtype for the returned array.
|
||||
|
||||
Returns:
|
||||
mx.array: Token IDs of shape (L,) or (1, L).
|
||||
"""
|
||||
ids: List[int] = []
|
||||
|
||||
if add_special_tokens and self.prepend_bos:
|
||||
ids.append(self.cls_id)
|
||||
|
||||
# Map each residue to its ID (defaulting to <unk> if not in vocab)
|
||||
for ch in sequence.upper():
|
||||
ids.append(self.token_to_id.get(ch, self.unk_id))
|
||||
|
||||
if add_special_tokens and self.append_eos:
|
||||
ids.append(self.eos_id)
|
||||
|
||||
arr = mx.array(ids, dtype=dtype)
|
||||
return mx.expand_dims(arr, axis=0) if return_batch_dim else arr
|
||||
|
||||
def batch_encode(
|
||||
self,
|
||||
sequences: Sequence[str],
|
||||
*,
|
||||
add_special_tokens: bool = True,
|
||||
max_length: Optional[int] = None,
|
||||
dtype=mx.int32,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Encode multiple protein sequences into a padded batch.
|
||||
|
||||
Args:
|
||||
sequences: List/sequence of protein strings.
|
||||
add_special_tokens: If True, add <cls> and <eos> tokens.
|
||||
max_length: If provided, truncate sequences to this length before padding.
|
||||
dtype: MLX dtype for the returned array.
|
||||
|
||||
Returns:
|
||||
mx.array: Tensor of shape (B, L) with right-padding using <pad> IDs.
|
||||
"""
|
||||
# Encode each sequence as (L,)
|
||||
encoded = [
|
||||
self.encode(s, add_special_tokens=add_special_tokens, dtype=dtype)
|
||||
for s in sequences
|
||||
]
|
||||
encoded = [e if e.ndim == 1 else e[0] for e in encoded]
|
||||
|
||||
if max_length is not None:
|
||||
encoded = [e[:max_length] for e in encoded]
|
||||
|
||||
# Find the longest sequence and right-pad all others
|
||||
max_len = max((int(e.shape[0]) for e in encoded), default=0)
|
||||
padded = []
|
||||
for e in encoded:
|
||||
pad_len = max_len - int(e.shape[0])
|
||||
if pad_len > 0:
|
||||
pad = mx.full((pad_len,), self.pad_id, dtype=dtype)
|
||||
e = mx.concatenate([e, pad], axis=0)
|
||||
padded.append(e)
|
||||
|
||||
return mx.stack(padded, axis=0) if padded else mx.array([], dtype=dtype)
|
||||
|
||||
def decode(
|
||||
self,
|
||||
token_ids: ArrayLike,
|
||||
*,
|
||||
skip_special_tokens: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Convert token IDs back into a protein sequence string.
|
||||
|
||||
Args:
|
||||
token_ids: 1-D or 2-D array/list of IDs. If 2-D, only the first row is decoded.
|
||||
skip_special_tokens: If True, remove all special tokens from output.
|
||||
|
||||
Returns:
|
||||
str: Protein sequence.
|
||||
"""
|
||||
# Normalize to a 1-D MLX array
|
||||
if hasattr(token_ids, "shape") and hasattr(token_ids, "tolist"):
|
||||
ids = token_ids if token_ids.ndim == 1 else token_ids[0]
|
||||
else:
|
||||
ids = mx.array(token_ids, dtype=mx.int32)
|
||||
|
||||
ids_list = [int(x) for x in ids.tolist()]
|
||||
toks: List[str] = []
|
||||
|
||||
for i in ids_list:
|
||||
tok = self.id_to_token.get(i, "<unk>")
|
||||
if skip_special_tokens and tok in {
|
||||
"<cls>",
|
||||
"<pad>",
|
||||
"<eos>",
|
||||
"<unk>",
|
||||
"<mask>",
|
||||
}:
|
||||
continue
|
||||
toks.append(tok)
|
||||
|
||||
return "".join(toks)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the size of the tokenizer’s vocabulary."""
|
||||
return len(self.vocab)
|
81
esm/main.py
Normal file
@@ -0,0 +1,81 @@
import argparse

import mlx.core as mx

from esm import ESM2


def main():
    parser = argparse.ArgumentParser(description="ESM-2 MLX Inference")
    parser.add_argument(
        "--model-path",
        default="checkpoints/mlx-esm2_t33_650M_UR50D",
        help="Path to MLX model checkpoint",
    )
    parser.add_argument(
        "--sequence",
        default="MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
        help="Protein sequence to test (default: human insulin)",
    )
    parser.add_argument(
        "--mask-position",
        type=int,
        default=None,
        help="Position to mask (default: middle of sequence)",
    )
    args = parser.parse_args()

    # Load pretrained ESM-2 model and tokenizer
    tokenizer, model = ESM2.from_pretrained(args.model_path)

    # Determine sequence and position to mask
    sequence = args.sequence.upper()
    mask_pos = (
        args.mask_position if args.mask_position is not None else len(sequence) // 2
    )
    if mask_pos >= len(sequence):
        mask_pos = len(sequence) - 1
    original_aa = sequence[mask_pos]  # The original amino acid at the masked position

    # Display input info
    print(f"Original sequence: {sequence}")
    print(f"Masked sequence: {sequence[:mask_pos]}<mask>{sequence[mask_pos+1:]}")
    print(f"Predicting position {mask_pos}: {original_aa}\n")

    # Tokenize sequence before and after the mask
    before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False)
    after = tokenizer.encode(sequence[mask_pos + 1 :], add_special_tokens=False)

    # Build token sequence with <cls>, <mask>, and <eos>
    tokens = mx.array(
        [
            [tokenizer.cls_id]
            + before.tolist()
            + [tokenizer.mask_id]
            + after.tolist()
            + [tokenizer.eos_id]
        ]
    )
    mask_token_pos = 1 + len(before)  # Position of the <mask> token

    # Run model to get logits for each token position
    logits = model(tokens)["logits"]
    probs = mx.softmax(
        logits[0, mask_token_pos, :]
    )  # Softmax over vocabulary at mask position

    # Get top-5 most likely tokens
    top_indices = mx.argsort(probs)[-5:][::-1]

    # Print predictions
    print("Top predictions:")
    for i, idx in enumerate(top_indices):
        token = tokenizer.vocab[int(idx)]
        prob = float(probs[idx])
        marker = "✓" if token == original_aa else " "
        print(f"{marker} {i+1}. {token}: {prob:.3f} ({prob*100:.1f}%)")


if __name__ == "__main__":
    main()
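The same masked-position distribution can also back a simple mutation-effect score: the log-ratio of the mutant and wild-type probabilities at the masked site, a common heuristic for ranking substitutions with protein language models. Below is a minimal sketch, assuming `probs` and `tokenizer` from the script above; the residues `wt` and `mt` are illustrative (in this script, `wt` would be `original_aa`).

```python
import math

wt, mt = "L", "P"  # hypothetical wild-type and mutant residues at the masked position

# Map single-letter amino acids to vocabulary ids; encoding without special
# tokens yields exactly one id per residue.
wt_id = int(tokenizer.encode(wt, add_special_tokens=False).tolist()[0])
mt_id = int(tokenizer.encode(mt, add_special_tokens=False).tolist()[0])

# Positive scores mean the model finds the mutant more plausible than the wild type.
score = math.log(float(probs[mt_id])) - math.log(float(probs[wt_id]))
print(f"{wt}->{mt} masked-marginal score: {score:.3f}")
```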
602
esm/notebooks/contact_prediction.ipynb
Normal file
@ -0,0 +1,602 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3fbacbe4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Predicting Protein Contacts with ESM-2\n",
|
||||
"\n",
|
||||
"Understanding how amino acids interact within a folded protein is essential for grasping protein function and stability. Contact prediction, the task of identifying which residues are close together in three-dimensional space, is a key step in the sequence to structure process. ESM-2’s learned attention patterns capture evolutionary signals that encode structural information, which allows the model to predict residue contacts directly from sequence data.\n",
|
||||
"\n",
|
||||
"In this notebook, we'll explore ESM-2's ability to predict protein contacts across three diverse proteins from different organisms:\n",
|
||||
"\n",
|
||||
"**Bacterial Transport:**\n",
|
||||
"- **1a3a (PTS Mannitol Component)**: A phosphoenolpyruvate-dependent sugar phosphotransferase system component from *E. coli*, essential for carbohydrate metabolism\n",
|
||||
"\n",
|
||||
"**Stress Response:**\n",
|
||||
"- **5ahw (Universal Stress Protein)**: A conserved stress response protein from *Mycolicibacterium smegmatis* that helps cells survive harsh conditions\n",
|
||||
"\n",
|
||||
"**Human Metabolism:**\n",
|
||||
"- **1xcr (Ester Hydrolase)**: A human enzyme (C11orf54) involved in lipid metabolism and cellular signaling pathways\n",
|
||||
"\n",
|
||||
"We will evaluate how effectively ESM-2 captures the structural relationships present in these sequences, measuring precision across different sequence separation ranges to assess both local and long-range contact prediction performance. This notebook is a modified version of a [notebook by the same name](https://github.com/facebookresearch/esm/blob/main/examples/contact_prediction.ipynb) from the [offical ESM repsitory](https://github.com/facebookresearch/esm)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "08352b12",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Setup\n",
|
||||
"\n",
|
||||
"Here we import all neccessary libraries."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c1047c94",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31mRunning cells with '.venv (Python 3.11.13)' requires the ipykernel package.\n",
|
||||
"\u001b[1;31mInstall 'ipykernel' into the Python environment. \n",
|
||||
"\u001b[1;31mCommand: '/Users/vincent/Developer/mlx-examples/.venv/bin/python -m pip install ipykernel -U --force-reinstall'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from typing import List, Tuple, Optional, Dict\n",
|
||||
"import string\n",
|
||||
"\n",
|
||||
"import mlx.core as mx\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from scipy.spatial.distance import squareform, pdist\n",
|
||||
"import biotite.structure as bs\n",
|
||||
"from biotite.database import rcsb\n",
|
||||
"from biotite.structure.io.pdbx import CIFFile, get_structure\n",
|
||||
"from Bio import SeqIO"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5f0af076",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Download multiple sequence alignment (MSA) files for our three test proteins from the ESM repository."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3264b66d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!mkdir -p data\n",
|
||||
"!curl -o data/1a3a_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/1a3a_1_A.a3m\n",
|
||||
"!curl -o data/5ahw_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/5ahw_1_A.a3m\n",
|
||||
"!curl -o data/1xcr_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/1xcr_1_A.a3m"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cbf1d0cb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Loading the model\n",
|
||||
"\n",
|
||||
"Load the ESM-2 model. Here we will use the 650M parameter version. Change the path below to point to your converted checkpoint."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4406e8a0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.append(\"..\")\n",
|
||||
"\n",
|
||||
"from esm import ESM2\n",
|
||||
"\n",
|
||||
"esm_checkpoint = \"../checkpoints/mlx-esm2_t33_650M_UR50D\"\n",
|
||||
"tokenizer, model = ESM2.from_pretrained(esm_checkpoint)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "77596456",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Defining functions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eb5f07ed",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Parsing alignments"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e754abd7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This function parses multiple sequence alignment files and clean up insertion artifacts. MSA files often contain lowercase letters and special characters (`.`, `*`) to indicate insertions relative to the reference sequence - these need to be removed to get the core aligned sequences."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "43717bea",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"deletekeys = dict.fromkeys(string.ascii_lowercase)\n",
|
||||
"deletekeys[\".\"] = None\n",
|
||||
"deletekeys[\"*\"] = None\n",
|
||||
"translation = str.maketrans(deletekeys)\n",
|
||||
"\n",
|
||||
"def read_sequence(filename: str) -> Tuple[str, str]:\n",
|
||||
" \"\"\" Reads the first (reference) sequences from a fasta or MSA file.\"\"\"\n",
|
||||
" record = next(SeqIO.parse(filename, \"fasta\"))\n",
|
||||
" return record.description, str(record.seq)\n",
|
||||
"\n",
|
||||
"def remove_insertions(sequence: str) -> str:\n",
|
||||
" \"\"\" Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. \"\"\"\n",
|
||||
" return sequence.translate(translation)\n",
|
||||
"\n",
|
||||
"def read_msa(filename: str) -> List[Tuple[str, str]]:\n",
|
||||
" \"\"\" Reads the sequences from an MSA file, automatically removes insertions.\"\"\"\n",
|
||||
" return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, \"fasta\")]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "628d7de1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Converting structures to contacts\n",
|
||||
"\n",
|
||||
"There are many ways to define a protein contact. Here we're using the definition of 8 angstroms between carbon beta atoms. Note that the position of the carbon beta is imputed from the position of the N, CA, and C atoms for each residue."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "21e0b44b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def extend(a, b, c, L, A, D):\n",
|
||||
" \"\"\"\n",
|
||||
" input: 3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral\n",
|
||||
" output: 4th coord\n",
|
||||
" \"\"\"\n",
|
||||
" def normalize(x):\n",
|
||||
" return x / np.linalg.norm(x, ord=2, axis=-1, keepdims=True)\n",
|
||||
"\n",
|
||||
" bc = normalize(b - c)\n",
|
||||
" n = normalize(np.cross(b - a, bc))\n",
|
||||
" m = [bc, np.cross(n, bc), n]\n",
|
||||
" d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]\n",
|
||||
" return c + sum([m * d for m, d in zip(m, d)])\n",
|
||||
"\n",
|
||||
"def contacts_from_pdb(\n",
|
||||
" structure: bs.AtomArray,\n",
|
||||
" distance_threshold: float = 8.0,\n",
|
||||
" chain: Optional[str] = None,\n",
|
||||
") -> np.ndarray:\n",
|
||||
" \"\"\"Extract contacts from PDB structure.\"\"\"\n",
|
||||
" mask = ~structure.hetero\n",
|
||||
" if chain is not None:\n",
|
||||
" mask &= structure.chain_id == chain\n",
|
||||
"\n",
|
||||
" N = structure.coord[mask & (structure.atom_name == \"N\")]\n",
|
||||
" CA = structure.coord[mask & (structure.atom_name == \"CA\")]\n",
|
||||
" C = structure.coord[mask & (structure.atom_name == \"C\")]\n",
|
||||
"\n",
|
||||
" Cbeta = extend(C, N, CA, 1.522, 1.927, -2.143)\n",
|
||||
" dist = squareform(pdist(Cbeta))\n",
|
||||
" \n",
|
||||
" contacts = dist < distance_threshold\n",
|
||||
" contacts = contacts.astype(np.int64)\n",
|
||||
" contacts[np.isnan(dist)] = -1\n",
|
||||
" return contacts"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5473f306",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Computing contact precisions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e361a9f3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calculate precision metrics to evaluate contact prediction quality. The `compute_precisions` function ranks predicted contacts by confidence and measures how many of the top predictions are true contacts, while `evaluate_prediction` breaks this down by sequence separation ranges (local, short, medium, long-range) since predicting distant contacts is typically much harder than nearby ones."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "62c37bbd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def compute_precisions(\n",
|
||||
" predictions: mx.array,\n",
|
||||
" targets: mx.array,\n",
|
||||
" minsep: int = 6,\n",
|
||||
" maxsep: Optional[int] = None,\n",
|
||||
" override_length: Optional[int] = None,\n",
|
||||
") -> Dict[str, mx.array]:\n",
|
||||
" \"\"\"Compute precision metrics for contact prediction.\"\"\"\n",
|
||||
" batch_size, seqlen, _ = predictions.shape\n",
|
||||
" \n",
|
||||
" if maxsep is not None:\n",
|
||||
" sep_mask_2d = mx.abs(mx.arange(seqlen)[None, :] - mx.arange(seqlen)[:, None]) <= maxsep\n",
|
||||
" targets = targets * sep_mask_2d[None, :]\n",
|
||||
" \n",
|
||||
" targets = targets.astype(mx.float32)\n",
|
||||
" src_lengths = (targets >= 0).sum(axis=-1).sum(axis=-1).astype(mx.float32)\n",
|
||||
" \n",
|
||||
" x_ind, y_ind = [], []\n",
|
||||
" for i in range(seqlen):\n",
|
||||
" for j in range(i + minsep, seqlen):\n",
|
||||
" x_ind.append(i)\n",
|
||||
" y_ind.append(j)\n",
|
||||
" \n",
|
||||
" x_ind = mx.array(x_ind)\n",
|
||||
" y_ind = mx.array(y_ind)\n",
|
||||
" \n",
|
||||
" predictions_upper = predictions[:, x_ind, y_ind]\n",
|
||||
" targets_upper = targets[:, x_ind, y_ind]\n",
|
||||
"\n",
|
||||
" topk = seqlen if override_length is None else max(seqlen, override_length)\n",
|
||||
" indices = mx.argsort(predictions_upper, axis=-1)[:, ::-1][:, :topk]\n",
|
||||
" \n",
|
||||
" batch_indices = mx.arange(batch_size)[:, None]\n",
|
||||
" topk_targets = targets_upper[batch_indices, indices]\n",
|
||||
" \n",
|
||||
" if topk_targets.shape[1] < topk:\n",
|
||||
" pad_shape = (topk_targets.shape[0], topk - topk_targets.shape[1])\n",
|
||||
" padding = mx.zeros(pad_shape)\n",
|
||||
" topk_targets = mx.concatenate([topk_targets, padding], 1)\n",
|
||||
"\n",
|
||||
" cumulative_dist = mx.cumsum(topk_targets, -1)\n",
|
||||
"\n",
|
||||
" gather_lengths = src_lengths[:, None]\n",
|
||||
" if override_length is not None:\n",
|
||||
" gather_lengths = override_length * mx.ones_like(gather_lengths)\n",
|
||||
"\n",
|
||||
" precision_fractions = mx.arange(0.1, 1.1, 0.1)\n",
|
||||
" gather_indices = (precision_fractions[None, :] * gather_lengths) - 1\n",
|
||||
" gather_indices = mx.clip(gather_indices, 0, cumulative_dist.shape[1] - 1)\n",
|
||||
" gather_indices = gather_indices.astype(mx.int32)\n",
|
||||
"\n",
|
||||
" binned_cumulative_dist = cumulative_dist[batch_indices, gather_indices]\n",
|
||||
" binned_precisions = binned_cumulative_dist / (gather_indices + 1)\n",
|
||||
"\n",
|
||||
" pl5 = binned_precisions[:, 1]\n",
|
||||
" pl2 = binned_precisions[:, 4]\n",
|
||||
" pl = binned_precisions[:, 9]\n",
|
||||
" auc = binned_precisions.mean(-1)\n",
|
||||
"\n",
|
||||
" return {\"AUC\": auc, \"P@L\": pl, \"P@L2\": pl2, \"P@L5\": pl5}\n",
|
||||
"\n",
|
||||
"def evaluate_prediction(\n",
|
||||
" predictions: mx.array,\n",
|
||||
" targets: mx.array,\n",
|
||||
") -> Dict[str, float]:\n",
|
||||
" \"\"\"Evaluate contact predictions across different sequence separation ranges.\"\"\"\n",
|
||||
" contact_ranges = [\n",
|
||||
" (\"local\", 3, 6),\n",
|
||||
" (\"short\", 6, 12),\n",
|
||||
" (\"medium\", 12, 24),\n",
|
||||
" (\"long\", 24, None),\n",
|
||||
" ]\n",
|
||||
" metrics = {}\n",
|
||||
" \n",
|
||||
" for name, minsep, maxsep in contact_ranges:\n",
|
||||
" rangemetrics = compute_precisions(\n",
|
||||
" predictions,\n",
|
||||
" targets,\n",
|
||||
" minsep=minsep,\n",
|
||||
" maxsep=maxsep,\n",
|
||||
" )\n",
|
||||
" for key, val in rangemetrics.items():\n",
|
||||
" metrics[f\"{name}_{key}\"] = float(val[0])\n",
|
||||
" return metrics"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5873e052",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Predicting contacts"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2d5778a9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This function wraps the tokenization and model inference steps, converting a raw amino acid sequence into token IDs and passing them through ESM-2's contact prediction head to produce a contact probability matrix."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dddf31a7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def predict_contacts(sequence: str, model, tokenizer) -> mx.array:\n",
|
||||
" \"\"\" Predict contacts for a given sequence \"\"\"\n",
|
||||
" tokens = tokenizer.encode(sequence)\n",
|
||||
" contacts = model.predict_contacts(tokens)\n",
|
||||
" return contacts"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "62562401",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Plotting results\n",
|
||||
"\n",
|
||||
"This function visualizes contacts as a symmetric matrix where both axes index residue positions. The lower triangle shows the model’s confidence as a blue heatmap, with darker cells indicating higher confidence. The upper triangle overlays evaluation markers: blue dots are correctly predicted contacts (true positives), red dots are predicted but not real (false positives), and grey dots are real contacts the model missed (false negatives)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "03e03791",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_contacts_and_predictions(\n",
|
||||
" predictions: mx.array,\n",
|
||||
" contacts: np.ndarray,\n",
|
||||
" ax,\n",
|
||||
" title: str,\n",
|
||||
" cmap: str = \"Blues\",\n",
|
||||
" ms: float = 1,\n",
|
||||
"):\n",
|
||||
" \"\"\"Plot contact predictions and true contacts.\"\"\"\n",
|
||||
" if isinstance(predictions, mx.array):\n",
|
||||
" predictions = np.array(predictions)\n",
|
||||
" \n",
|
||||
" seqlen = contacts.shape[0]\n",
|
||||
" relative_distance = np.add.outer(-np.arange(seqlen), np.arange(seqlen))\n",
|
||||
" bottom_mask = relative_distance < 0\n",
|
||||
" masked_image = np.ma.masked_where(bottom_mask, predictions)\n",
|
||||
" invalid_mask = np.abs(np.add.outer(np.arange(seqlen), -np.arange(seqlen))) < 6\n",
|
||||
" predictions_copy = predictions.copy()\n",
|
||||
" predictions_copy[invalid_mask] = float(\"-inf\")\n",
|
||||
"\n",
|
||||
" topl_val = np.sort(predictions_copy.reshape(-1))[-seqlen]\n",
|
||||
" pred_contacts = predictions_copy >= topl_val\n",
|
||||
" true_positives = contacts & pred_contacts & ~bottom_mask\n",
|
||||
" false_positives = ~contacts & pred_contacts & ~bottom_mask\n",
|
||||
" other_contacts = contacts & ~pred_contacts & ~bottom_mask\n",
|
||||
"\n",
|
||||
" ax.imshow(masked_image, cmap=cmap)\n",
|
||||
" ax.plot(*np.where(other_contacts), \"o\", c=\"grey\", ms=ms)\n",
|
||||
" ax.plot(*np.where(false_positives), \"o\", c=\"r\", ms=ms)\n",
|
||||
" ax.plot(*np.where(true_positives), \"o\", c=\"b\", ms=ms)\n",
|
||||
" ax.set_title(title)\n",
|
||||
" ax.axis(\"square\")\n",
|
||||
" ax.set_xlim([0, seqlen])\n",
|
||||
" ax.set_ylim([0, seqlen])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9364c984",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Predict and visualize\n",
|
||||
"Here we'll use ESM-2 contact prediction on our three test proteins and evaluate the results. We'll compute precision metrics across different sequence separation ranges and create contact maps that visualize both the model's predictions and how well they match the true protein structures."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9fa9e59e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Read Data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7da50dc2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Load experimental protein structures from the Protein Data Bank and extract true contact maps for evaluation, while also parsing the reference sequences from our MSA files that will serve as input to ESM-2."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2d276137",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"PDB_IDS = [\"1a3a\", \"5ahw\", \"1xcr\"]\n",
|
||||
"\n",
|
||||
"structures = {\n",
|
||||
" name.lower(): get_structure(CIFFile.read(rcsb.fetch(name, \"cif\")))[0]\n",
|
||||
" for name in PDB_IDS\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"contacts = {\n",
|
||||
" name: contacts_from_pdb(structure, chain=\"A\") \n",
|
||||
" for name, structure in structures.items()\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"msas = {\n",
|
||||
" name: read_msa(f\"data/{name.lower()}_1_A.a3m\")\n",
|
||||
" for name in PDB_IDS\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"sequences = {\n",
|
||||
" name: msa[0] for name, msa in msas.items()\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4ce64f18",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### ESM-2 predictions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1f2da88f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Evaluate predictions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0adb0a11",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This loop generates contact predictions for each protein using ESM-2, compares them against the experimentally determined structures, and computes precision metrics across different sequence separation ranges to evaluate model performance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "941b4afa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"predictions = {}\n",
|
||||
"results = []\n",
|
||||
"\n",
|
||||
"for pdb_id in sequences:\n",
|
||||
" _, sequence = sequences[pdb_id]\n",
|
||||
" prediction = predict_contacts(sequence, model, tokenizer)\n",
|
||||
" predictions[pdb_id] = prediction[0]\n",
|
||||
" \n",
|
||||
" true_contacts = mx.array(contacts[pdb_id])\n",
|
||||
" \n",
|
||||
" min_len = min(prediction.shape[1], true_contacts.shape[0])\n",
|
||||
" pred_trimmed = prediction[:, :min_len, :min_len]\n",
|
||||
" true_trimmed = true_contacts[:min_len, :min_len]\n",
|
||||
" true_trimmed = mx.expand_dims(true_trimmed, axis=0)\n",
|
||||
" \n",
|
||||
" metrics = evaluate_prediction(pred_trimmed, true_trimmed)\n",
|
||||
" result = {\"id\": pdb_id, \"model\": \"ESM-2 (Unsupervised)\"}\n",
|
||||
" result.update(metrics)\n",
|
||||
" results.append(result)\n",
|
||||
"\n",
|
||||
"results_df = pd.DataFrame(results)\n",
|
||||
"display(results_df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c5c7418a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The results demonstrate that ESM-2 excels at predicting long-range contacts, with precision scores ranging from 40.9% to 86.4% for residues more than 24 positions apart. Performance is consistently higher for distant contacts compared to local ones. For example, the universal stress protein (5ahw) achieves 86.4% precision for long-range contacts but only 2.4% for local contacts between 3 and 6 residues apart. This trend is observed across all three proteins, with medium-range contacts (12–24 residues apart) and short-range contacts (6–12 residues apart) showing intermediate accuracy. These results suggest that ESM-2 has learned to identify evolutionarily conserved structural motifs that connect distant regions of the sequence, which are often critical for protein fold stability and function."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "487cff51",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Plot contacts and predictions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "10291191",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This analysis generates contact map visualizations for all three proteins, presenting ESM-2’s predictions as heatmaps and overlaying the true experimental contacts as colored dots."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "628efc10",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"proteins = [r['id'] for r in results]\n",
|
||||
"fig, axes = plt.subplots(figsize=(6 * len(proteins), 6), ncols=len(proteins))\n",
|
||||
"if len(proteins) == 1:\n",
|
||||
" axes = [axes]\n",
|
||||
"\n",
|
||||
"for ax, pdb_id in zip(axes, proteins):\n",
|
||||
" prediction = predictions[pdb_id]\n",
|
||||
" target = contacts[pdb_id]\n",
|
||||
" \n",
|
||||
" result = next(r for r in results if r['id'] == pdb_id)\n",
|
||||
" long_pl = result['long_P@L']\n",
|
||||
" \n",
|
||||
" plot_contacts_and_predictions(\n",
|
||||
" prediction, target, ax=ax, \n",
|
||||
" title=f\"{pdb_id}: Long Range P@L: {100 * long_pl:.1f}%\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "99e1edaf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The contact maps highlight ESM-2’s strong ability to detect long-range structural relationships. In each panel, the lower triangle shows model predictions, where darker blue regions indicate high-confidence contacts, and the upper triangle shows the corresponding experimental data. Correct predictions appear as blue dots, forming distinct off-diagonal patterns in 5ahw and 1a3a that capture key global fold interactions. Red dots mark false positives, which are relatively rare, while grey dots represent missed contacts. These missed contacts are notably more frequent in 1xcr, consistent with its lower long-range precision. The dense clusters of blue true positives in 5ahw, compared to the sparser, fragmented patterns in 1xcr, clearly illustrate the variation in predictive performance across proteins."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
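As a standalone reference for the headline metric reported above (long-range P@L: the precision of the top-L predicted contacts among residue pairs separated by at least 24 positions), here is a minimal NumPy sketch. It is independent of the notebook's MLX implementation; the function name, random toy inputs, and threshold defaults are illustrative only.

```python
import numpy as np


def long_range_p_at_l(pred: np.ndarray, true: np.ndarray, minsep: int = 24) -> float:
    """Precision of the top-L predicted contacts with |i - j| >= minsep.

    pred: (L, L) contact probabilities; true: (L, L) binary contact map.
    """
    L = pred.shape[0]
    i, j = np.triu_indices(L, k=minsep)        # upper-triangle pairs at >= minsep separation
    order = np.argsort(pred[i, j])[::-1][:L]   # top-L most confident long-range pairs
    return float(true[i, j][order].mean())


# Toy usage with random symmetric inputs (real inputs would come from
# predict_contacts and contacts_from_pdb as in the notebook).
rng = np.random.default_rng(0)
L = 100
pred = rng.random((L, L))
pred = (pred + pred.T) / 2
true = (rng.random((L, L)) < 0.05).astype(int)
true = np.maximum(true, true.T)
print(long_range_p_at_l(pred, true))
```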
334
esm/notebooks/embeddings.ipynb
Normal file
File diff suppressed because one or more lines are too long
662
esm/notebooks/mutation_effect_prediction.ipynb
Normal file
File diff suppressed because one or more lines are too long
12
esm/requirements.txt
Normal file
@ -0,0 +1,12 @@
mlx
torch
transformers
numpy
pandas
seaborn
biopython
biotite
scipy
tqdm
scikit-learn
matplotlib
121
esm/test.py
Normal file
@ -0,0 +1,121 @@
import unittest

import numpy as np
from transformers import AutoTokenizer, EsmConfig, EsmForMaskedLM

from esm import ESM2

# Paths for MLX and Hugging Face versions of ESM-2
MLX_PATH = "checkpoints/mlx-esm2_t12_35M_UR50D"
HF_PATH = "facebook/esm2_t12_35M_UR50D"


def load_mlx_model():
    """Load MLX ESM-2 model and tokenizer."""
    tokenizer, model = ESM2.from_pretrained(MLX_PATH)
    return tokenizer, model


def load_hf_model():
    """Load Hugging Face ESM-2 model and tokenizer with hidden states + attentions."""
    tokenizer = AutoTokenizer.from_pretrained(HF_PATH)
    config = EsmConfig.from_pretrained(
        HF_PATH, output_hidden_states=True, output_attentions=True
    )
    model = EsmForMaskedLM.from_pretrained(HF_PATH, config=config)
    return tokenizer, model


class TestESM2(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Load both MLX and HF models/tokenizers once for all tests
        cls.mlx_tokenizer, cls.mlx_model = load_mlx_model()
        cls.hf_tokenizer, cls.hf_model = load_hf_model()

    def test_tokenizer(self):
        """Verify MLX tokenizer matches Hugging Face tokenizer behavior."""
        self.assertEqual(len(self.mlx_tokenizer), len(self.hf_tokenizer))

        sequences = [
            "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
        ]

        # Compare batched tokenization (padded sequences)
        mlx_batch = self.mlx_tokenizer.batch_encode(sequences)
        hf_batch = (
            self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"]
            .cpu()
            .numpy()
        )
        self.assertEqual(tuple(mlx_batch.shape), tuple(hf_batch.shape))
        self.assertTrue(
            np.array_equal(np.array(mlx_batch.tolist(), dtype=hf_batch.dtype), hf_batch)
        )

        # Compare single-sequence encode/decode
        for sequence in sequences:
            mlx_tokens = self.mlx_tokenizer.encode(sequence)
            hf_tokens = (
                self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"]
                .cpu()
                .numpy()
                .tolist()[0]
            )
            self.assertTrue(np.array_equal(mlx_tokens, hf_tokens))
            self.assertEqual(
                self.mlx_tokenizer.decode(mlx_tokens),
                self.hf_tokenizer.decode(hf_tokens).replace(" ", ""),
            )
            self.assertEqual(
                self.mlx_tokenizer.decode(mlx_tokens, skip_special_tokens=True),
                self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(
                    " ", ""
                ),
            )

    def test_model(self):
        """Verify MLX and HF model outputs match (logits, hidden states, attentions)."""
        sequences = [
            "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
        ]
        for sequence in sequences:
            # Tokenize
            mlx_tokens = self.mlx_tokenizer.encode(sequence, return_batch_dim=True)
            hf_tokens = self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"]

            # Forward pass
            mlx_outputs = self.mlx_model(
                mlx_tokens,
                repr_layers=[self.mlx_model.num_layers],
                need_head_weights=True,
            )
            hf_outputs = self.hf_model(input_ids=hf_tokens)

            # Compare logits
            mlx_logits = np.array(mlx_outputs["logits"])
            hf_logits = hf_outputs["logits"].detach().cpu().numpy()
            self.assertTrue(np.allclose(mlx_logits, hf_logits, rtol=1e-4, atol=1e-4))

            # Compare final-layer hidden states
            final_layer = self.mlx_model.num_layers
            mlx_hidden_states = np.array(mlx_outputs["representations"][final_layer])
            hf_hidden_states = hf_outputs["hidden_states"][-1].detach().cpu().numpy()
            self.assertTrue(
                np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4)
            )

            # Compare attentions for final layer
            mlx_attentions = np.array(
                mlx_outputs["attentions"][:, final_layer - 1, :, :, :]
            )
            hf_attentions = hf_outputs["attentions"][-1].detach().cpu().numpy()
            self.assertTrue(
                np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4)
            )


if __name__ == "__main__":
    unittest.main()
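These parity tests assume the 35M checkpoint has already been converted (for example via `python convert.py --hf-path facebook/esm2_t12_35M_UR50D`, which should place it at `checkpoints/mlx-esm2_t12_35M_UR50D`) and that the Hugging Face reference weights can be downloaded; they can then be run with `python test.py`.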