Vincent Amato
2025-08-15 23:48:57 -04:00
parent 4b2a0df237
commit eccabdd227
21 changed files with 13729 additions and 0 deletions

157
esm/README.md Normal file

@@ -0,0 +1,157 @@
# ESM-2
This repository provides an implementation of Meta's ESM-2 protein language model
in MLX.[^1] ESM-2 is Meta's second-generation Evolutionary Scale Model, a
transformer-based protein language model trained on millions of diverse protein
sequences with a masked language modeling objective.
![Example contact prediction map](assets/contact_prediction.png)
_Example contact prediction map for a universal stress protein. In this case, ESM-2 650M achieves 86.4% precision at long-range contacts._
## Setup
Install the requirements:
```bash
pip install -r requirements.txt
```
## Usage
Below are the available ESM-2 models:
| Model | Parameters | Layers |
|-------|------------|--------|
| [`esm2_t6_8M_UR50D`](https://huggingface.co/facebook/esm2_t6_8M_UR50D) | 8M | 6 |
| [`esm2_t12_35M_UR50D`](https://huggingface.co/facebook/esm2_t12_35M_UR50D) | 35M | 12 |
| [`esm2_t30_150M_UR50D`](https://huggingface.co/facebook/esm2_t30_150M_UR50D) | 150M | 30 |
| [`esm2_t33_650M_UR50D`](https://huggingface.co/facebook/esm2_t33_650M_UR50D) | 650M | 33 |
| [`esm2_t36_3B_UR50D`](https://huggingface.co/facebook/esm2_t36_3B_UR50D) | 3B | 36 |
| [`esm2_t48_15B_UR50D`](https://huggingface.co/facebook/esm2_t48_15B_UR50D) | 15B | 48 |
Convert a model to MLX format:
```bash
python convert.py --hf-path facebook/esm2_t33_650M_UR50D
```
This saves the converted model to `checkpoints/mlx-<model_name>` (e.g. `checkpoints/mlx-esm2_t33_650M_UR50D`).
### Basic Inference
```python
from esm import ESM2
# Load model and tokenizer
tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")
# Example protein sequence (human insulin)
sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
# Tokenize and run inference
tokens = tokenizer.encode(sequence)
result = model(tokens)
logits = result["logits"] # Shape: (batch, length, vocab_size)
```
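As a quick sanity check, the logits can be turned back into per-position residue predictions. The sketch below continues from the snippet above (it assumes `tokenizer`, `model`, and `result` are already defined):
```python
import mlx.core as mx

# Most likely token ID at every position: (batch, length)
pred_ids = mx.argmax(result["logits"], axis=-1)

# Decode the first sequence, dropping <cls>/<eos>/<pad> tokens
predicted = tokenizer.decode(pred_ids, skip_special_tokens=True)
print(predicted)
```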
### Masked Language Modeling
For a complete example, see `main.py`:
```bash
python main.py --sequence "YOUR_SEQUENCE" --mask-position 50
```
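If you prefer to call the model directly, the snippet below is a condensed sketch of what `main.py` does: the target residue is replaced by `<mask>`, and the softmax at that position gives a distribution over amino acids. The sequence and position here are arbitrary placeholders.
```python
import mlx.core as mx

from esm import ESM2

tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")

sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # placeholder sequence
mask_pos = 10                                 # 0-based position to mask

# Build <cls> ... <mask> ... <eos> token IDs around the masked residue
before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False)
after = tokenizer.encode(sequence[mask_pos + 1:], add_special_tokens=False)
tokens = mx.array(
    [[tokenizer.cls_id] + before.tolist() + [tokenizer.mask_id]
     + after.tolist() + [tokenizer.eos_id]]
)

# Distribution over the vocabulary at the masked position (offset by <cls>)
probs = mx.softmax(model(tokens)["logits"][0, 1 + len(before), :], axis=-1)
best = int(mx.argmax(probs))
print(f"Predicted residue: {tokenizer.vocab[best]} ({float(probs[best]):.3f})")
```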
### Embeddings
```python
# Get sequence-level representations
seq_repr = model.get_sequence_representations(tokens, layer=-1) # Shape: (batch, embed_dim)
# Extract per-residue representations from specific layers
representations = model.extract_features(tokens, repr_layers=[20, 30, 33])
final_layer = representations[33] # Shape: (batch, length, embed_dim)
```
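Sequence-level embeddings can be compared directly, for example with cosine similarity. A minimal sketch, continuing from the loaded `tokenizer` and `model` above (the two sequences are arbitrary placeholders):
```python
import mlx.core as mx

seqs = [
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # placeholder sequence A
    "MALWMRLLPLLALLALWGPDPAAA",           # placeholder sequence B
]
tokens = tokenizer.batch_encode(seqs)

emb = model.get_sequence_representations(tokens)  # Shape: (2, embed_dim)
a, b = emb[0], emb[1]
cosine = float(mx.sum(a * b) / (mx.sqrt(mx.sum(a * a)) * mx.sqrt(mx.sum(b * b))))
print(f"Cosine similarity: {cosine:.3f}")
```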
### Contact Prediction
```python
# Predict residue-residue contacts
contacts = model.predict_contacts(tokens) # Shape: (batch, length, length)
# Or get contacts along with other outputs
result = model(tokens, return_contacts=True)
contacts = result["contacts"]
attentions = result["attentions"] # Shape: (batch, layers, heads, length, length)
```
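To inspect the strongest long-range predictions, the contact map can be converted to NumPy (already in the requirements) and ranked. A minimal sketch continuing from `contacts` above; the sequence-separation cutoff of 24 residues is the usual convention for "long-range" contacts:
```python
import numpy as np

cmap = np.array(contacts[0])             # (L', L') contact probabilities
length = cmap.shape[0]
i, j = np.triu_indices(length, k=24)     # pairs separated by >= 24 residues
top = np.argsort(cmap[i, j])[::-1][:10]  # ten highest-probability pairs
for a, b in zip(i[top], j[top]):
    # Indices refer to the trimmed sequence (CLS/EOS rows are removed)
    print(f"{a:4d} - {b:4d}  p = {cmap[a, b]:.3f}")
```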
### Examples
**Mutation Effect Prediction**: [notebooks/mutation_effect_prediction.ipynb](notebooks/mutation_effect_prediction.ipynb)
This notebook demonstrates how to use ESM-2 for zero-shot mutation effect prediction by scoring amino acid substitutions based on their likelihood under the model. We validate the approach using experimental fitness data from β-lactamase TEM, showing how ESM-2 captures functional constraints without requiring structural information.
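One way to score a substitution, and roughly what the notebook does, is a masked-marginal comparison: mask the mutated position and compare the model's log-probabilities of the mutant and wild-type residues. A rough sketch (the notebook may differ in details; `tokenizer` and `model` are assumed to be loaded as in Basic Inference):
```python
import mlx.core as mx

def mutation_score(sequence: str, pos: int, mut_aa: str) -> float:
    """Log-odds of the mutant vs. wild-type residue at a masked (0-based) position."""
    wt_aa = sequence[pos]
    before = tokenizer.encode(sequence[:pos], add_special_tokens=False)
    after = tokenizer.encode(sequence[pos + 1:], add_special_tokens=False)
    tokens = mx.array(
        [[tokenizer.cls_id] + before.tolist() + [tokenizer.mask_id]
         + after.tolist() + [tokenizer.eos_id]]
    )
    log_probs = mx.log(mx.softmax(model(tokens)["logits"][0, 1 + len(before), :], axis=-1))
    return float(log_probs[tokenizer.token_to_id[mut_aa]]
                 - log_probs[tokenizer.token_to_id[wt_aa]])
```
A positive score means the model considers the mutant residue more likely than the wild type at that position.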
**Embeddings**: [notebooks/embeddings.ipynb](notebooks/embeddings.ipynb)
This notebook explores how ESM-2 generates meaningful protein embeddings that capture evolutionary and functional relationships between proteins. We analyze six diverse human proteins to demonstrate how the learned representations cluster proteins by function and reveal biological similarities.
**Contact Prediction**: [notebooks/contact_prediction.ipynb](notebooks/contact_prediction.ipynb)
This notebook shows how to predict residue-residue contacts in protein structures using ESM-2's attention patterns. We evaluate contact prediction performance on three diverse proteins, demonstrating how the model captures both local and long-range structural relationships directly from sequence data.
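Contact accuracy is commonly reported as precision of the top-L long-range predictions, where L is the sequence length. A rough sketch of that metric, assuming a boolean ground-truth map `true_map` aligned with the predicted map (both names are illustrative):
```python
import numpy as np

def long_range_precision_at_L(pred_map: np.ndarray, true_map: np.ndarray, min_sep: int = 24) -> float:
    """Fraction of the top-L long-range predictions that are true contacts."""
    length = pred_map.shape[0]
    i, j = np.triu_indices(length, k=min_sep)
    top = np.argsort(pred_map[i, j])[::-1][:length]
    return float(true_map[i[top], j[top]].mean())
```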
### Benchmarking
Benchmark MLX performance:
```bash
python benchmarks/benchmark_mx.py
```
Benchmark PyTorch MPS performance:
```bash
python benchmarks/benchmark_pt.py
```
Expected performance on an M4 MacBook Pro (batch_size = 5):
- MLX: 299 ms per step, 16.71 sequences/sec
- PyTorch MPS: 402 ms per step, 12.43 sequences/sec
### Testing
Verify correctness against the original implementation:
```bash
python test.py
```
This tests tokenizer and model outputs (logits, hidden states, and attentions) for equivalence with the original implementation.
### Citations
```bibtex
@article{rives2019biological,
author={Rives, Alexander and Meier, Joshua and Sercu, Tom and Goyal, Siddharth and Lin, Zeming and Liu, Jason and Guo, Demi and Ott, Myle and Zitnick, C. Lawrence and Ma, Jerry and Fergus, Rob},
title={Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences},
year={2019},
doi={10.1101/622803},
url={https://www.biorxiv.org/content/10.1101/622803v4},
journal={PNAS}
}
```
```bibtex
@article{lin2023evolutionary,
title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
author={Lin, Zeming and Akin, Halil and Rao, Roshan and Hie, Brian and Zhu, Ziheng and Lu, Wenting and Smetanin, Nikita and Verkuil, Robert and Kabeli, Ori and Shmueli, Yilun and dos Santos Costa, Allan and Fazel-Zarandi, Maryam and Sercu, Tom and Candido, Salvatore and Rives, Alexander},
journal={Science},
volume={379},
number={6637},
pages={1123--1130},
year={2023},
publisher={American Association for the Advancement of Science}
}
```
[^1]: Refer to the [paper](https://www.science.org/doi/10.1126/science.ade2574) and [code](https://github.com/facebookresearch/esm) for more details.

Binary file not shown (image, 34 KiB).


@@ -0,0 +1,47 @@
import sys
import time
from pathlib import Path
import mlx.core as mx
# Add parent directory to Python path
cur_path = Path(__file__).parents[1].resolve()
sys.path.append(str(cur_path))
from esm import ESM2
# Example protein sequence (Green Fluorescent Protein)
protein_sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
# Load pretrained ESM-2 model and its tokenizer from local checkpoint
tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")
# Number of sequences to process in each forward pass
batch_size = 5
# Number of timing iterations for performance measurement
steps = 50
# Tokenize the protein sequence into integer IDs for the model
# Replicate the same sequence 'batch_size' times to create a batch
tokens = tokenizer.batch_encode([protein_sequence] * batch_size)
# Warm-up phase
for _ in range(10):
result = model(tokens)
mx.eval(result["logits"]) # Force computation to complete
# Measure average inference time over 'steps' iterations
tic = time.time()
for _ in range(steps):
result = model(tokens)
mx.eval(result["logits"]) # Synchronize and ensure computation finishes
toc = time.time()
# Compute metrics: average time per step (ms) and throughput (sequences/sec)
ms_per_step = 1000 * (toc - tic) / steps
throughput = batch_size * 1000 / ms_per_step
# Display results
print(f"Time (ms) per step: {ms_per_step:.3f}")
print(f"Throughput: {throughput:.2f} sequences/sec")


@@ -0,0 +1,52 @@
import time
import torch
from transformers import AutoTokenizer, EsmForMaskedLM
# Example protein sequence (Green Fluorescent Protein)
protein_sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
# Hugging Face model identifier for ESM-2 (33 layers, 650M params, UR50D training set)
model_name = "facebook/esm2_t33_650M_UR50D"
# Load tokenizer and model; move model to Apple Metal Performance Shaders (MPS) device
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = EsmForMaskedLM.from_pretrained(model_name).to("mps")
# Number of sequences per forward pass
batch_size = 5
# Number of timing iterations
steps = 50
# Tokenize input sequence and replicate for the batch
# Replicate the same sequence 'batch_size' times to create a batch
inputs = tokenizer(
[protein_sequence] * batch_size,
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024
)
input_ids = inputs["input_ids"].to("mps")
attention_mask = inputs["attention_mask"].to("mps")
# Warm-up phase
for _ in range(10):
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
torch.mps.synchronize() # Ensure all queued ops on MPS are complete before next step
# Timed inference loop
tic = time.time()
for _ in range(steps):
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
torch.mps.synchronize() # Wait for computation to finish before timing next iteration
toc = time.time()
# Compute performance metrics
ms_per_step = 1000 * (toc - tic) / steps
throughput = batch_size * 1000 / ms_per_step
# Report results
print(f"Time (ms) per step: {ms_per_step:.3f}")
print(f"Throughput: {throughput:.2f} sequences/sec")

178
esm/convert.py Normal file

@@ -0,0 +1,178 @@
import argparse
import json
import shutil
from pathlib import Path
from typing import Dict, Optional
import mlx.core as mx
import torch
from huggingface_hub import snapshot_download
def download(hf_repo: str) -> Path:
"""Download model from Hugging Face."""
return Path(
snapshot_download(
repo_id=hf_repo,
allow_patterns=["*.safetensors", "*.json", "*.bin", "*.txt"],
)
)
def remap_key(key: str) -> Optional[str]:
"""Remap HuggingFace ESM key names to MLX format."""
# Skip position embeddings and position_ids
if "position_embeddings" in key or "position_ids" in key:
return None
# Map lm_head components properly
if key == "lm_head.decoder.weight":
return "lm_head.weight"
if key == "lm_head.decoder.bias":
return "lm_head.bias"
if key == "lm_head.dense.weight":
return "lm_head.dense.weight"
if key == "lm_head.dense.bias":
return "lm_head.dense.bias"
if key == "lm_head.layer_norm.weight":
return "lm_head.layer_norm.weight"
if key == "lm_head.layer_norm.bias":
return "lm_head.layer_norm.bias"
# Core remapping patterns
key = key.replace("esm.embeddings.word_embeddings", "embed_tokens")
key = key.replace("esm.encoder.emb_layer_norm_after", "emb_layer_norm_after")
key = key.replace("esm.encoder.layer.", "layer_")
key = key.replace("esm.contact_head", "contact_head")
key = key.replace("lm_head", "lm_head")
# Attention patterns
key = key.replace(".attention.self.", ".self_attn.")
key = key.replace(".attention.output.dense", ".self_attn.out_proj")
key = key.replace(".attention.LayerNorm", ".self_attn_layer_norm")
key = key.replace(".query", ".q_proj")
key = key.replace(".key", ".k_proj")
key = key.replace(".value", ".v_proj")
key = key.replace(".rotary_embeddings", ".rot_emb")
# FFN patterns
key = key.replace(".intermediate.dense", ".fc1")
key = key.replace(".output.dense", ".fc2")
key = key.replace(".LayerNorm", ".final_layer_norm")
return key
def load_weights(model_path: Path) -> Dict:
"""Load weights from safetensors or PyTorch bin files."""
# Check for safetensors file
safetensors_path = model_path / "model.safetensors"
if safetensors_path.exists():
print("Loading from safetensors...")
return mx.load(str(safetensors_path))
# Check for single bin file
single_bin_path = model_path / "pytorch_model.bin"
if single_bin_path.exists():
print("Loading from pytorch_model.bin...")
state_dict = torch.load(str(single_bin_path), map_location="cpu")
return {k: v.numpy() for k, v in state_dict.items()}
# Check for sharded bin files
index_file = model_path / "pytorch_model.bin.index.json"
if index_file.exists():
print("Loading from sharded bin files...")
with open(index_file) as f:
index = json.load(f)
# Get unique shard files
shard_files = set(index["weight_map"].values())
# Load all shards
state_dict = {}
for shard_file in sorted(shard_files):
print(f" Loading shard: {shard_file}")
shard_path = model_path / shard_file
shard_dict = torch.load(str(shard_path), map_location="cpu")
state_dict.update(shard_dict)
return {k: v.numpy() for k, v in state_dict.items()}
raise ValueError(f"No model weights found in {model_path}")
def convert(model_path: Path) -> Dict[str, mx.array]:
"""Convert ESM weights to MLX format."""
# Load weights from any format
weights = load_weights(model_path)
# Convert keys and create MLX arrays
mlx_weights = {}
for key, value in weights.items():
mlx_key = remap_key(key)
if mlx_key is not None:
mlx_weights[mlx_key] = mx.array(value) if not isinstance(value, mx.array) else value
# If lm_head.weight is missing but embed_tokens.weight exists, set up weight sharing
# (This is for smaller models that don't have a separate lm_head.decoder.weight)
if "lm_head.weight" not in mlx_weights and "embed_tokens.weight" in mlx_weights:
mlx_weights["lm_head.weight"] = mlx_weights["embed_tokens.weight"]
return mlx_weights
def main():
parser = argparse.ArgumentParser(description="Convert ESM weights to MLX format")
parser.add_argument(
"--hf-path",
default="facebook/esm2_t6_8M_UR50D",
help="Hugging Face model path"
)
parser.add_argument(
"--mlx-path",
default=None,
help="Output path for MLX model"
)
parser.add_argument(
"--checkpoints-dir",
default="checkpoints",
help="Directory to save checkpoints (default: checkpoints)"
)
args = parser.parse_args()
# Download model
print(f"Downloading {args.hf_path}...")
model_path = download(args.hf_path)
# Set output path
if args.mlx_path is None:
model_name = args.hf_path.split("/")[-1]
checkpoints_dir = Path(args.checkpoints_dir)
checkpoints_dir.mkdir(parents=True, exist_ok=True)
args.mlx_path = checkpoints_dir / f"mlx-{model_name}"
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
# Convert weights
print("Converting weights...")
mlx_weights = convert(model_path)
# Save weights
print(f"Saving MLX weights to {mlx_path}...")
mx.save_safetensors(str(mlx_path / "model.safetensors"), mlx_weights)
# Copy config and other files
print("Copying config...")
shutil.copy(model_path / "config.json", mlx_path / "config.json")
for file_name in ["special_tokens_map.json", "tokenizer.json", "vocab.txt"]:
src_file = model_path / file_name
if src_file.exists():
shutil.copy(src_file, mlx_path / file_name)
print(f"Conversion complete! MLX model saved to {mlx_path}")
if __name__ == "__main__":
main()

19
esm/esm/__init__.py Normal file

@@ -0,0 +1,19 @@
"""
ESM-2 protein language model implementation in MLX
"""
from .model import ESM2
from .tokenizer import ProteinTokenizer
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
from .attention import MultiheadAttention
from .rotary_embedding import RotaryEmbedding
__all__ = [
'ESM2',
'ProteinTokenizer',
'ContactPredictionHead',
'RobertaLMHead',
'TransformerLayer',
'MultiheadAttention',
'RotaryEmbedding'
]

150
esm/esm/attention.py Normal file

@@ -0,0 +1,150 @@
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .rotary_embedding import RotaryEmbedding
class MultiheadAttention(nn.Module):
"""
Multi-head attention layer with rotary position embeddings, as used in ESM-2.
This module implements both self-attention (when `key` and `value` are not
provided) and cross-attention. It projects input sequences into queries,
keys, and values, applies rotary position embeddings to encode relative
position information, computes scaled dot-product attention over multiple
heads in parallel, and returns a combined output projection.
Args:
embed_dim (int): Total embedding dimension of the model input and output.
num_heads (int): Number of parallel attention heads. Must divide `embed_dim`.
"""
def __init__(
self,
embed_dim,
num_heads,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim**-0.5
# Linear projections for queries, keys, and values (with bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# Linear projection for output (with bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# ESM-2 uses rotary embeddings
self.rot_emb = RotaryEmbedding(dim=self.head_dim)
def __call__(
self,
query,
key: Optional[mx.array] = None,
value: Optional[mx.array] = None,
key_padding_mask: Optional[mx.array] = None,
attn_mask: Optional[mx.array] = None,
need_head_weights: bool = False,
) -> Tuple[mx.array, Optional[mx.array]]:
"""
Multi-head attention forward pass.
Args:
query: Tensor of shape (tgt_len, batch, embed_dim).
key: Optional tensor of shape (src_len, batch, embed_dim). Defaults to `query`.
value: Optional tensor of shape (src_len, batch, embed_dim). Defaults to `query`.
key_padding_mask: Optional mask of shape (batch, src_len) to ignore padded positions.
attn_mask: Optional mask for attention (e.g., causal mask).
need_head_weights: If True, return attention weights for each head separately.
Returns:
attn_output: Tensor of shape (tgt_len, batch, embed_dim).
attn_weights_out: Attention weights of shape
(num_heads, batch, tgt_len, src_len) if per-head,
or (batch, tgt_len, src_len) if averaged.
"""
tgt_len, bsz, embed_dim = query.shape
assert embed_dim == self.embed_dim
# For self-attention, use query as key and value if not provided
if key is None:
key = query
if value is None:
value = query
# Project queries, keys, values
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q = q * self.scaling
# Reshape for multi-head attention
q = q.reshape(tgt_len, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
k = k.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
v = v.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
src_len = k.shape[1]
# Apply rotary embeddings if present
if self.rot_emb:
q, k = self.rot_emb(q, k)
# Compute attention weights
attn_weights = q @ k.swapaxes(-2, -1)
assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
# Apply attention mask
if attn_mask is not None:
attn_mask = mx.expand_dims(attn_mask, 0)
attn_weights = attn_weights + attn_mask
# Apply key padding mask
if key_padding_mask is not None:
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len)
# Convert key_padding_mask to boolean and expand dimensions
# key_padding_mask: [bsz, src_len] -> [bsz, 1, 1, src_len]
mask = mx.expand_dims(mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2)
# Apply mask: set attention to -inf where mask is True (padded positions)
attn_weights = mx.where(mask, -mx.inf, attn_weights)
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
# Apply softmax
attn_weights_float = mx.softmax(attn_weights.astype(mx.float32), axis=-1)
attn_probs = attn_weights_float
# Compute attention output
attn = attn_probs @ v
assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
# Reshape output
attn = attn.swapaxes(0, 1).reshape(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
# Return attention weights if requested
attn_weights_out: Optional[mx.array] = None
if need_head_weights:
# Return attention weights for each head separately
attn_weights_out = attn_weights_float.reshape(
bsz, self.num_heads, tgt_len, src_len
).astype(attn.dtype).swapaxes(0, 1)
else:
# Return averaged attention weights
attn_weights_out = mx.mean(
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len),
axis=1
).astype(attn.dtype)
return attn, attn_weights_out

336
esm/esm/model.py Normal file

@@ -0,0 +1,336 @@
from typing import List, Dict, Optional, Tuple
from pathlib import Path
import json
import mlx.core as mx
import mlx.nn as nn
from .tokenizer import ProteinTokenizer
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
class ESM2(nn.Module):
"""
ESM-2 protein language model in MLX.
Args:
num_layers (int): Number of transformer layers.
embed_dim (int): Embedding dimension.
attention_heads (int): Number of attention heads.
tokenizer (Optional[ProteinTokenizer]): Tokenizer to use (created if None).
token_dropout (bool): Apply token-dropout masking behavior.
"""
def __init__(
self,
num_layers: int = 33,
embed_dim: int = 1280,
attention_heads: int = 20,
tokenizer: Optional[ProteinTokenizer] = None,
token_dropout: bool = True,
):
super().__init__()
self.num_layers = num_layers
self.embed_dim = embed_dim
self.attention_heads = attention_heads
# Initialize tokenizer
if tokenizer is None:
tokenizer = ProteinTokenizer()
self.tokenizer = tokenizer
self.vocab_size = len(tokenizer)
# Special token IDs / config
self.padding_idx = tokenizer.pad_id
self.mask_idx = tokenizer.mask_id
self.cls_idx = tokenizer.cls_id
self.eos_idx = tokenizer.eos_id
self.prepend_bos = tokenizer.prepend_bos
self.append_eos = tokenizer.append_eos
self.token_dropout = token_dropout
self._init_submodules()
def _init_submodules(self) -> None:
"""Initialize embeddings, transformer stack, and output heads."""
self.embed_scale = 1
# Token embeddings
self.embed_tokens = nn.Embedding(self.vocab_size, self.embed_dim)
# Transformer layers (register each layer so MLX tracks parameters)
self._layer_indices = list(range(self.num_layers))
for i in self._layer_indices:
layer = TransformerLayer(
self.embed_dim,
4 * self.embed_dim, # FFN dimension = 4×embed_dim
self.attention_heads,
)
setattr(self, f"layer_{i}", layer)
# Contact prediction head (uses all layers × heads attentions)
self.contact_head = ContactPredictionHead(
self.num_layers * self.attention_heads,
self.prepend_bos,
self.append_eos,
eos_idx=self.eos_idx,
)
# Final norm + LM head (tied weights)
self.emb_layer_norm_after = nn.LayerNorm(self.embed_dim)
self.lm_head = RobertaLMHead(
embed_dim=self.embed_dim,
output_dim=self.vocab_size,
weight=self.embed_tokens.weight,
)
def __call__(
self,
tokens: mx.array,
repr_layers: List[int] = [],
need_head_weights: bool = False,
return_contacts: bool = False,
) -> Dict[str, mx.array]:
"""
Forward pass through ESM-2.
Args:
tokens: Tensor of token IDs with shape (B, T).
repr_layers: Layers to return hidden states from (0..num_layers).
need_head_weights: If True, return attention weights.
return_contacts: If True, compute residue-residue contact probabilities.
Returns:
dict with:
logits: (B, T, V)
representations: {layer_idx: (B, T, E)}
attentions: (B, L, H, T, T) if requested
contacts: (B, T', T') if requested
"""
if return_contacts:
need_head_weights = True
# Ensure tokens is 2D (B, T)
if tokens.ndim == 1:
tokens = mx.expand_dims(tokens, axis=0)
assert tokens.ndim == 2
# Padding mask (B, T)
padding_mask = mx.equal(tokens, self.padding_idx)
# Embed tokens (B, T, E)
x = self.embed_scale * self.embed_tokens(tokens)
# Token dropout: zero masked tokens + rescale based on observed mask ratio
if self.token_dropout:
mask_positions = mx.equal(tokens, self.mask_idx)
x = mx.where(mask_positions[:, :, None], mx.zeros_like(x), x)
mask_ratio_train = 0.15 * 0.8
src_lengths = mx.sum(~padding_mask, axis=-1, keepdims=True)
mask_ratio_observed = mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
scale_factor = (1 - mask_ratio_train) / mx.maximum(1 - mask_ratio_observed, 1e-8)
x = x * scale_factor[:, None, :]
# Zero out padding positions
if padding_mask.any():
x = x * (1 - padding_mask[:, :, None].astype(x.dtype))
# Track requested representations
repr_layers = set(repr_layers)
hidden_representations: Dict[int, mx.array] = {}
if 0 in repr_layers:
hidden_representations[0] = x
if need_head_weights:
attn_weights: List[mx.array] = []
# (B, T, E) -> (T, B, E) for transformer layers
x = mx.swapaxes(x, 0, 1)
# If no padding anywhere, skip the mask
if not padding_mask.any():
padding_mask = None
# Transformer stack
for layer_idx in self._layer_indices:
layer = getattr(self, f"layer_{layer_idx}")
x, attn = layer(
x,
self_attn_padding_mask=padding_mask,
need_head_weights=need_head_weights,
)
# Save hidden representation if requested (store back as (B, T, E))
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = mx.swapaxes(x, 0, 1)
# Save per-layer attentions if requested (H, B, T, T) -> (B, H, T, T)
if need_head_weights:
attn_weights.append(mx.swapaxes(attn, 0, 1))
# Final layer norm, back to (B, T, E)
x = self.emb_layer_norm_after(x)
x = mx.swapaxes(x, 0, 1)
# Save final hidden if requested
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x
# Language modeling logits
x = self.lm_head(x)
# Build result dict
result: Dict[str, mx.array] = {
"logits": x,
"representations": hidden_representations,
}
# Collect attentions and optional contacts
if need_head_weights:
# Stack layers -> (B, L, H, T, T)
attentions = mx.stack(attn_weights, axis=1)
# Mask out padded positions if present
if padding_mask is not None:
attention_mask = 1 - padding_mask.astype(attentions.dtype)
attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(attention_mask, 2)
attentions = attentions * attention_mask[:, None, None, :, :]
result["attentions"] = attentions
if return_contacts:
contacts = self.contact_head(tokens, attentions)
result["contacts"] = contacts
return result
def predict_contacts(self, tokens: mx.array) -> mx.array:
"""
Predict residue-residue contacts.
Args:
tokens: Tensor of shape (B, T).
Returns:
mx.array: Contact probabilities of shape (B, T', T').
"""
return self(tokens, return_contacts=True)["contacts"]
def extract_features(
self,
tokens: mx.array,
repr_layers: Optional[List[int]] = None,
return_all_hiddens: bool = False,
) -> Dict[int, mx.array]:
"""
Extract hidden representations from selected layers.
Args:
tokens: Tensor of shape (B, T).
repr_layers: Layer indices to return (default: last layer).
return_all_hiddens: If True, return all layers (0..num_layers).
Returns:
dict: {layer_idx: (B, T, E)} for requested layers.
"""
if return_all_hiddens:
repr_layers = list(range(self.num_layers + 1))
elif repr_layers is None:
repr_layers = [self.num_layers]
result = self(tokens, repr_layers=repr_layers)
return result["representations"]
def get_sequence_representations(
self,
tokens: mx.array,
layer: int = -1,
) -> mx.array:
"""
Average token representations into a per-sequence embedding.
Args:
tokens: Tensor of shape (B, T).
layer: Layer index to use (-1 or num_layers for last).
Returns:
mx.array: Sequence embeddings of shape (B, E).
"""
if layer == -1:
layer = self.num_layers
representations = self.extract_features(tokens, repr_layers=[layer])
repr = representations[layer]
# Mask: non-padding and not CLS; optionally not EOS
mask = mx.logical_and(
mx.not_equal(tokens, self.padding_idx),
mx.not_equal(tokens, self.cls_idx),
)
if self.append_eos:
mask = mx.logical_and(mask, mx.not_equal(tokens, self.eos_idx))
# Mean over valid positions
mask = mask[:, :, None].astype(repr.dtype)
masked_repr = repr * mask
seq_lens = mx.sum(mask, axis=1, keepdims=True)
seq_repr = mx.sum(masked_repr, axis=1) / mx.maximum(seq_lens[:, :, 0], 1.0)
return seq_repr
@classmethod
def from_pretrained(cls, model_path: str) -> Tuple[ProteinTokenizer, "ESM2"]:
"""
Load model weights and config from a directory.
Expects:
- config.json
- model.safetensors
- vocab.txt (optional, will use default if not found)
- special_tokens_map.json (optional, will use default if not found)
Args:
model_path: Path to directory with weights and config.
Returns:
(tokenizer, model): Initialized tokenizer and ESM2 model.
"""
model_dir = Path(model_path)
config_path = model_dir / "config.json"
with open(config_path, "r") as f:
config = json.load(f)
# Check for vocab and special tokens files
vocab_path = model_dir / "vocab.txt"
special_tokens_path = model_dir / "special_tokens_map.json"
if vocab_path.exists() and special_tokens_path.exists():
tokenizer = ProteinTokenizer(
vocab_file=str(vocab_path),
special_tokens_map_file=str(special_tokens_path),
)
else:
tokenizer = ProteinTokenizer()
model = cls(
num_layers=config["num_hidden_layers"],
embed_dim=config["hidden_size"],
attention_heads=config["num_attention_heads"],
tokenizer=tokenizer,
token_dropout=config["token_dropout"],
)
# Load safetensors as nested dict and update model params
weights_path = model_dir / "model.safetensors"
flat_weights = mx.load(str(weights_path))
nested_weights: Dict[str, dict] = {}
for key, value in flat_weights.items():
parts = key.split(".")
cur = nested_weights
for p in parts[:-1]:
cur = cur.setdefault(p, {})
cur[parts[-1]] = value
model.update(nested_weights)
return tokenizer, model

212
esm/esm/modules.py Normal file

@@ -0,0 +1,212 @@
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from .attention import MultiheadAttention
def symmetrize(x: mx.array) -> mx.array:
"""
Make a tensor symmetric over its last two dimensions.
Args:
x: Tensor with shape (..., L, L).
Returns:
mx.array: Symmetrized tensor of shape (..., L, L).
"""
# Add tensor to its transpose over the last two dims
return x + mx.swapaxes(x, -1, -2)
def apc(x: mx.array) -> mx.array:
"""
Apply Average Product Correction (APC) to remove background co-variation.
Args:
x: Tensor with shape (..., L, L).
Returns:
mx.array: APC-corrected tensor of shape (..., L, L).
"""
# Compute row, column, and total sums
a1 = mx.sum(x, axis=-1, keepdims=True)
a2 = mx.sum(x, axis=-2, keepdims=True)
a12 = mx.sum(x, axis=(-1, -2), keepdims=True)
# Expected co-variation under independence
expected = (a1 * a2) / a12
return x - expected
class TransformerLayer(nn.Module):
"""
Transformer layer used in ESM-2.
Args:
embed_dim (int): Model embedding dimension.
ffn_embed_dim (int): Hidden dimension of the feed-forward network.
attention_heads (int): Number of attention heads.
"""
def __init__(
self,
embed_dim: int,
ffn_embed_dim: int,
attention_heads: int,
):
super().__init__()
self.embed_dim = embed_dim
self.ffn_embed_dim = ffn_embed_dim
self.attention_heads = attention_heads
self._init_submodules()
def _init_submodules(self) -> None:
"""Initialize attention, norms, and feed-forward submodules."""
self.self_attn = MultiheadAttention(self.embed_dim, self.attention_heads)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def __call__(
self,
x: mx.array,
self_attn_mask: Optional[mx.array] = None,
self_attn_padding_mask: Optional[mx.array] = None,
need_head_weights: bool = False,
):
"""
Forward pass for the Transformer layer.
Args:
x: Tensor of shape (seq_len, batch, embed_dim).
self_attn_mask: Optional attention mask.
self_attn_padding_mask: Optional padding mask of shape (batch, seq_len).
need_head_weights: If True, return per-head attention weights.
Returns:
x: Tensor of shape (seq_len, batch, embed_dim).
attn: Attention weights of shape
(num_heads, batch, tgt_len, src_len) if per-head,
or (batch, tgt_len, src_len) if averaged.
"""
# Self-attention block
residual = x
x = self.self_attn_layer_norm(x)
x, attn = self.self_attn(
query=x,
key_padding_mask=self_attn_padding_mask,
attn_mask=self_attn_mask,
need_head_weights=need_head_weights,
)
x = residual + x
# Feed-forward block
residual = x
x = self.final_layer_norm(x)
x = nn.gelu(self.fc1(x))
x = self.fc2(x)
x = residual + x
return x, attn
class RobertaLMHead(nn.Module):
"""
Masked Language Modeling (MLM) head with tied weights.
Args:
embed_dim (int): Embedding dimension of the backbone.
output_dim (int): Vocabulary size.
weight (mx.array): Embedding weight matrix for tied projection.
"""
def __init__(self, embed_dim: int, output_dim: int, weight: mx.array):
super().__init__()
self.dense = nn.Linear(embed_dim, embed_dim)
self.layer_norm = nn.LayerNorm(embed_dim)
self.weight = weight
self.bias = mx.zeros(output_dim)
def __call__(self, features: mx.array) -> mx.array:
"""
Forward pass for the MLM head.
Args:
features: Tensor of shape (batch, seq_len, embed_dim).
Returns:
mx.array: Logits of shape (batch, seq_len, output_dim).
"""
# Transform features before projection to vocab
x = self.dense(features)
x = nn.gelu(x)
x = self.layer_norm(x)
return mx.matmul(x, self.weight.T) + self.bias
class ContactPredictionHead(nn.Module):
"""
Predict residue-residue contact probabilities from attention maps.
Args:
in_features (int): Number of attention channels (layers × heads).
prepend_bos (bool): If True, drop BOS/CLS token attentions.
append_eos (bool): If True, drop EOS token attentions.
bias (bool): Whether the regression layer uses a bias term.
eos_idx (Optional[int]): Token ID for EOS; required if append_eos=True.
"""
def __init__(
self,
in_features: int,
prepend_bos: bool,
append_eos: bool,
bias: bool = True,
eos_idx: Optional[int] = None,
):
super().__init__()
self.in_features = in_features
self.prepend_bos = prepend_bos
self.append_eos = append_eos
if append_eos and eos_idx is None:
raise ValueError("append_eos=True but eos_idx was not provided.")
self.eos_idx = eos_idx
self.regression = nn.Linear(in_features, 1, bias=bias)
def __call__(self, tokens: mx.array, attentions: mx.array) -> mx.array:
"""
Forward pass for contact prediction.
Args:
tokens: Tensor of shape (B, T).
attentions: Tensor of shape (B, L, H, T, T).
Returns:
mx.array: Contact probabilities of shape (B, T', T'),
where T' = T - [prepend_bos] - [append_eos].
"""
# Remove EOS attentions if requested
if self.append_eos:
eos_mask = mx.not_equal(tokens, self.eos_idx).astype(attentions.dtype)
eos_mask = eos_mask[:, None, :] * eos_mask[:, :, None]
attentions = attentions * eos_mask[:, None, None, :, :]
attentions = attentions[..., :-1, :-1]
# Remove BOS attentions if requested
if self.prepend_bos:
attentions = attentions[..., 1:, 1:]
# Merge (layers × heads) into channel dimension
batch_size, layers, heads, seqlen, _ = attentions.shape
attentions = attentions.reshape(batch_size, layers * heads, seqlen, seqlen)
# Symmetrize and apply APC to enhance contact signal
attentions = apc(symmetrize(attentions))
# Apply logistic regression over channels
attentions = mx.transpose(attentions, axes=[0, 2, 3, 1])
logits = self.regression(attentions)
return nn.sigmoid(mx.squeeze(logits, axis=3))

112
esm/esm/rotary_embedding.py Normal file

@@ -0,0 +1,112 @@
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
def rotate_half(x: mx.array) -> mx.array:
"""
Rotate last dimension by splitting into two halves and swapping.
Args:
x: Tensor with even-sized last dimension.
Returns:
mx.array: Tensor of same shape as `x` with halves rotated.
"""
# Split into two equal halves along the last dimension
x1, x2 = mx.split(x, 2, axis=-1)
# Swap halves and negate the second half
return mx.concatenate((-x2, x1), axis=-1)
def apply_rotary_pos_emb(x: mx.array, cos: mx.array, sin: mx.array) -> mx.array:
"""
Apply rotary position embeddings to a tensor.
Args:
x: Input tensor of shape (..., seq_len, dim).
cos: Cosine embedding table of shape (1, seq_len, dim).
sin: Sine embedding table of shape (1, seq_len, dim).
Returns:
mx.array: Tensor with rotary position embeddings applied.
"""
# Trim cos/sin to match the sequence length of x
cos = cos[:, : x.shape[-2], :]
sin = sin[:, : x.shape[-2], :]
# Elementwise rotation: (x * cos) + (rotate_half(x) * sin)
return (x * cos) + (rotate_half(x) * sin)
class RotaryEmbedding(nn.Module):
"""
Rotary position embedding (RoPE) module.
Args:
dim (int): Head dimension size (must be even).
"""
def __init__(self, dim: int, *_, **__):
super().__init__()
# Precompute inverse frequency for each pair of dimensions
self.inv_freq = 1.0 / (10000 ** (mx.arange(0, dim, 2).astype(mx.float32) / dim))
# Cache for cosine/sine tables to avoid recomputation
self._seq_len_cached = None
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_tables(self, x: mx.array, seq_dimension: int = 1) -> Tuple[mx.array, mx.array]:
"""
Compute and cache cos/sin tables for the given sequence length.
Args:
x: Reference tensor for sequence length.
seq_dimension: Axis containing the sequence length.
Returns:
Tuple of:
cos: Cosine table of shape (1, seq_len, dim).
sin: Sine table of shape (1, seq_len, dim).
"""
seq_len = x.shape[seq_dimension]
# Only update cache if sequence length has changed
if seq_len != self._seq_len_cached:
self._seq_len_cached = seq_len
# Time steps: shape (seq_len,)
t = mx.arange(seq_len).astype(self.inv_freq.dtype)
# Outer product between time and inverse frequency
freqs = mx.einsum("i,j->ij", t, self.inv_freq)
# Duplicate frequencies for cos/sin dimensions
emb = mx.concatenate((freqs, freqs), axis=-1)
self._cos_cached = mx.cos(emb)[None, :, :]
self._sin_cached = mx.sin(emb)[None, :, :]
return self._cos_cached, self._sin_cached
def __call__(self, q: mx.array, k: mx.array) -> Tuple[mx.array, mx.array]:
"""
Apply rotary position embeddings to queries and keys.
Args:
q: Query tensor of shape (..., seq_len, dim).
k: Key tensor of shape (..., seq_len, dim).
Returns:
Tuple of:
q_rot: Query tensor with RoPE applied.
k_rot: Key tensor with RoPE applied.
"""
# Get (and cache) cos/sin tables based on key sequence length
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
k, seq_dimension=-2
)
# Apply rotary embeddings to both q and k
return (
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
)

213
esm/esm/tokenizer.py Normal file

@@ -0,0 +1,213 @@
from typing import List, Sequence, Union, Optional
import json
from pathlib import Path
import mlx.core as mx
# Canonical amino-acid tokens (IUPAC standard + uncommon variants)
PROTEIN_TOKENS = [
"L", "A", "G", "V", "S", "E", "R", "T", "I", "D",
"P", "K", "Q", "N", "F", "Y", "M", "H", "W", "C",
"X", "B", "U", "Z", "O", ".", "-"
]
ArrayLike = Union[List[int], mx.array]
class ProteinTokenizer:
"""
Protein sequence tokenizer compatible with ESM-2.
This class converts protein sequences into token IDs and back, following
the vocabulary, special tokens, and formatting rules used by ESM-2.
"""
def __init__(
self,
vocab_file: Optional[str] = None,
special_tokens_map_file: Optional[str] = None,
):
"""
Initialize the ProteinTokenizer.
Args:
vocab_file: Optional path to a file containing the vocabulary,
one token per line.
special_tokens_map_file: Optional path to a JSON file defining
special token names and values.
If both files are provided, they override the default vocabulary and
special token mappings. Otherwise, defaults are loaded.
"""
# Load vocabulary from files if given, otherwise use built-in defaults
if vocab_file and special_tokens_map_file:
self._load_from_files(vocab_file, special_tokens_map_file)
else:
self._load_default_vocab()
# Create token ↔ ID mappings
self.token_to_id = {tok: i for i, tok in enumerate(self.vocab)}
self.id_to_token = {i: tok for i, tok in enumerate(self.vocab)}
# Cache special token IDs
self.cls_id = self.token_to_id["<cls>"]
self.pad_id = self.token_to_id["<pad>"]
self.eos_id = self.token_to_id["<eos>"]
self.unk_id = self.token_to_id["<unk>"]
self.mask_id = self.token_to_id["<mask>"]
# Behavior flags for ESM-2-style BOS/EOS
self.prepend_bos = True
self.append_eos = True
def _load_from_files(self, vocab_file: str, special_tokens_map_file: str) -> None:
"""Load vocabulary and special tokens from the provided files."""
# Vocabulary file: one token per line
vocab_path = Path(vocab_file)
with open(vocab_path, "r", encoding="utf-8") as f:
self.vocab = [line.strip() for line in f if line.strip()]
# Special tokens mapping file (JSON)
special_tokens_path = Path(special_tokens_map_file)
with open(special_tokens_path, "r", encoding="utf-8") as f:
self.special_tokens_map = json.load(f)
def _load_default_vocab(self) -> None:
"""Load the built-in ESM vocabulary and special token mapping."""
# ESM convention: prepend special tokens, then amino acids, then <mask>
prepend_toks = ["<cls>", "<pad>", "<eos>", "<unk>"]
append_toks = ["<mask>"]
self.vocab = prepend_toks + PROTEIN_TOKENS
# Pad vocab size to multiple of 8 (original implementation detail)
while len(self.vocab) % 8 != 0:
self.vocab.append(f"<null_{len(self.vocab) - len(prepend_toks)}>")
self.vocab.extend(append_toks)
# Default special tokens map
self.special_tokens_map = {
"cls_token": "<cls>",
"pad_token": "<pad>",
"eos_token": "<eos>",
"unk_token": "<unk>",
"mask_token": "<mask>",
}
def encode(
self,
sequence: str,
*,
add_special_tokens: bool = True,
return_batch_dim: bool = False,
dtype=mx.int32,
) -> mx.array:
"""
Convert a protein sequence into token IDs.
Args:
sequence: Protein sequence (case-insensitive).
add_special_tokens: If True, add <cls> at the start and <eos> at the end.
return_batch_dim: If True, output shape will be (1, L) instead of (L,).
dtype: MLX dtype for the returned array.
Returns:
mx.array: Token IDs of shape (L,) or (1, L).
"""
ids: List[int] = []
if add_special_tokens and self.prepend_bos:
ids.append(self.cls_id)
# Map each residue to its ID (defaulting to <unk> if not in vocab)
for ch in sequence.upper():
ids.append(self.token_to_id.get(ch, self.unk_id))
if add_special_tokens and self.append_eos:
ids.append(self.eos_id)
arr = mx.array(ids, dtype=dtype)
return mx.expand_dims(arr, axis=0) if return_batch_dim else arr
def batch_encode(
self,
sequences: Sequence[str],
*,
add_special_tokens: bool = True,
max_length: Optional[int] = None,
dtype=mx.int32,
) -> mx.array:
"""
Encode multiple protein sequences into a padded batch.
Args:
sequences: List/sequence of protein strings.
add_special_tokens: If True, add <cls> and <eos> tokens.
max_length: If provided, truncate sequences to this length before padding.
dtype: MLX dtype for the returned array.
Returns:
mx.array: Tensor of shape (B, L) with right-padding using <pad> IDs.
"""
# Encode each sequence as (L,)
encoded = [
self.encode(s, add_special_tokens=add_special_tokens, dtype=dtype)
for s in sequences
]
encoded = [e if e.ndim == 1 else e[0] for e in encoded]
if max_length is not None:
encoded = [e[:max_length] for e in encoded]
# Find the longest sequence and right-pad all others
max_len = max((int(e.shape[0]) for e in encoded), default=0)
padded = []
for e in encoded:
pad_len = max_len - int(e.shape[0])
if pad_len > 0:
pad = mx.full((pad_len,), self.pad_id, dtype=dtype)
e = mx.concatenate([e, pad], axis=0)
padded.append(e)
return mx.stack(padded, axis=0) if padded else mx.array([], dtype=dtype)
def decode(
self,
token_ids: ArrayLike,
*,
skip_special_tokens: bool = False,
) -> str:
"""
Convert token IDs back into a protein sequence string.
Args:
token_ids: 1-D or 2-D array/list of IDs. If 2-D, only the first row is decoded.
skip_special_tokens: If True, remove all special tokens from output.
Returns:
str: Protein sequence.
"""
# Normalize to a 1-D MLX array
if hasattr(token_ids, "shape") and hasattr(token_ids, "tolist"):
ids = token_ids if token_ids.ndim == 1 else token_ids[0]
else:
ids = mx.array(token_ids, dtype=mx.int32)
ids_list = [int(x) for x in ids.tolist()]
toks: List[str] = []
for i in ids_list:
tok = self.id_to_token.get(i, "<unk>")
if skip_special_tokens and tok in {
"<cls>", "<pad>", "<eos>", "<unk>", "<mask>"
}:
continue
toks.append(tok)
return "".join(toks)
def __len__(self) -> int:
"""Return the size of the tokenizers vocabulary."""
return len(self.vocab)

66
esm/main.py Normal file

@@ -0,0 +1,66 @@
import argparse
import mlx.core as mx
from esm import ESM2
def main():
parser = argparse.ArgumentParser(description="ESM-2 MLX Inference")
parser.add_argument(
"--model-path",
default="checkpoints/mlx-esm2_t33_650M_UR50D",
help="Path to MLX model checkpoint"
)
parser.add_argument(
"--sequence",
default="MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
help="Protein sequence to test (default: human insulin)"
)
parser.add_argument(
"--mask-position",
type=int,
default=None,
help="Position to mask (default: middle of sequence)"
)
args = parser.parse_args()
# Load pretrained ESM-2 model and tokenizer
tokenizer, model = ESM2.from_pretrained(args.model_path)
# Determine sequence and position to mask
sequence = args.sequence.upper()
mask_pos = args.mask_position if args.mask_position is not None else len(sequence) // 2
if mask_pos >= len(sequence):
mask_pos = len(sequence) - 1
original_aa = sequence[mask_pos] # The original amino acid at masked position
# Display input info
print(f"Original sequence: {sequence}")
print(f"Masked sequence: {sequence[:mask_pos]}<mask>{sequence[mask_pos+1:]}")
print(f"Predicting position {mask_pos}: {original_aa}\n")
# Tokenize sequence before and after the mask
before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False)
after = tokenizer.encode(sequence[mask_pos+1:], add_special_tokens=False)
# Build token sequence with <cls>, <mask>, and <eos>
tokens = mx.array([[tokenizer.cls_id] + before.tolist() + [tokenizer.mask_id] + after.tolist() + [tokenizer.eos_id]])
mask_token_pos = 1 + len(before) # Position of <mask> token
# Run model to get logits for each token position
logits = model(tokens)["logits"]
probs = mx.softmax(logits[0, mask_token_pos, :]) # Softmax over vocabulary at mask position
# Get top-5 most likely tokens
top_indices = mx.argsort(probs)[-5:][::-1]
# Print predictions
print("Top predictions:")
for i, idx in enumerate(top_indices):
token = tokenizer.vocab[int(idx)]
if token in tokenizer.vocab:
prob = float(probs[idx])
marker = "" if token == original_aa else " "
print(f"{marker} {i+1}. {token}: {prob:.3f} ({prob*100:.1f}%)")
if __name__ == "__main__":
main()

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

12
esm/requirements.txt Normal file

@@ -0,0 +1,12 @@
mlx
torch
transformers
numpy
pandas
seaborn
biopython
biotite
scipy
tqdm
scikit-learn
matplotlib

101
esm/test.py Normal file

@@ -0,0 +1,101 @@
import unittest
import numpy as np
from transformers import AutoTokenizer, EsmForMaskedLM, EsmConfig
from esm import ESM2
# Paths for MLX and Hugging Face versions of ESM-2
MLX_PATH = "checkpoints/mlx-esm2_t12_35M_UR50D"
HF_PATH = "facebook/esm2_t12_35M_UR50D"
def load_mlx_model():
"""Load MLX ESM-2 model and tokenizer."""
tokenizer, model = ESM2.from_pretrained(MLX_PATH)
return tokenizer, model
def load_hf_model():
"""Load Hugging Face ESM-2 model and tokenizer with hidden states + attentions."""
tokenizer = AutoTokenizer.from_pretrained(HF_PATH)
config = EsmConfig.from_pretrained(
HF_PATH,
output_hidden_states=True,
output_attentions=True
)
model = EsmForMaskedLM.from_pretrained(HF_PATH, config=config)
return tokenizer, model
class TestESM2(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Load both MLX and HF models/tokenizers once for all tests
cls.mlx_tokenizer, cls.mlx_model = load_mlx_model()
cls.hf_tokenizer, cls.hf_model = load_hf_model()
def test_tokenizer(self):
"""Verify MLX tokenizer matches Hugging Face tokenizer behavior."""
self.assertEqual(len(self.mlx_tokenizer), len(self.hf_tokenizer))
sequences = [
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
]
# Compare batched tokenization (padded sequences)
mlx_batch = self.mlx_tokenizer.batch_encode(sequences)
hf_batch = self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"].cpu().numpy()
self.assertEqual(tuple(mlx_batch.shape), tuple(hf_batch.shape))
self.assertTrue(
np.array_equal(np.array(mlx_batch.tolist(), dtype=hf_batch.dtype), hf_batch)
)
# Compare single-sequence encode/decode
for sequence in sequences:
mlx_tokens = self.mlx_tokenizer.encode(sequence)
hf_tokens = self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"].cpu().numpy().tolist()[0]
self.assertTrue(np.array_equal(mlx_tokens, hf_tokens))
self.assertEqual(
self.mlx_tokenizer.decode(mlx_tokens),
self.hf_tokenizer.decode(hf_tokens).replace(" ", "")
)
self.assertEqual(
self.mlx_tokenizer.decode(mlx_tokens, skip_special_tokens=True),
self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(" ", "")
)
def test_model(self):
"""Verify MLX and HF model outputs match (logits, hidden states, attentions)."""
sequences = [
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
]
for sequence in sequences:
# Tokenize
mlx_tokens = self.mlx_tokenizer.encode(sequence, return_batch_dim=True)
hf_tokens = self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"]
# Forward pass
mlx_outputs = self.mlx_model(
mlx_tokens,
repr_layers=[self.mlx_model.num_layers],
need_head_weights=True
)
hf_outputs = self.hf_model(input_ids=hf_tokens)
# Compare logits
mlx_logits = np.array(mlx_outputs["logits"])
hf_logits = hf_outputs["logits"].detach().cpu().numpy()
self.assertTrue(np.allclose(mlx_logits, hf_logits, rtol=1e-4, atol=1e-4))
# Compare final-layer hidden states
final_layer = self.mlx_model.num_layers
mlx_hidden_states = np.array(mlx_outputs["representations"][final_layer])
hf_hidden_states = hf_outputs["hidden_states"][-1].detach().cpu().numpy()
self.assertTrue(np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4))
# Compare attentions for final layer
mlx_attentions = np.array(mlx_outputs["attentions"][:, final_layer-1, :, :, :])
hf_attentions = hf_outputs["attentions"][-1].detach().cpu().numpy()
self.assertTrue(np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4))
if __name__ == "__main__":
unittest.main()