{ "cells": [ { "cell_type": "markdown", "id": "3fbacbe4", "metadata": {}, "source": [ "## Predicting Protein Contacts with ESM-2\n", "\n", "Understanding how amino acids interact within a folded protein is essential for grasping protein function and stability. Contact prediction, the task of identifying which residues are close together in three-dimensional space, is a key step in going from sequence to structure. ESM-2’s learned attention patterns capture evolutionary signals that encode structural information, which allows the model to predict residue contacts directly from sequence data.\n", "\n", "In this notebook, we'll explore ESM-2's ability to predict protein contacts across three diverse proteins from different organisms:\n", "\n", "**Bacterial Transport:**\n", "- **1a3a (PTS Mannitol Component)**: A phosphoenolpyruvate-dependent sugar phosphotransferase system component from *E. coli*, essential for carbohydrate metabolism\n", "\n", "**Stress Response:**\n", "- **5ahw (Universal Stress Protein)**: A conserved stress response protein from *Mycolicibacterium smegmatis* that helps cells survive harsh conditions\n", "\n", "**Human Metabolism:**\n", "- **1xcr (Ester Hydrolase)**: A human enzyme (C11orf54) involved in lipid metabolism and cellular signaling pathways\n", "\n", "We will evaluate how effectively ESM-2 captures the structural relationships present in these sequences, measuring precision across different sequence separation ranges to assess both local and long-range contact prediction performance. This notebook is a modified version of a [notebook by the same name](https://github.com/facebookresearch/esm/blob/main/examples/contact_prediction.ipynb) from the [official ESM repository](https://github.com/facebookresearch/esm)." ] }, { "cell_type": "markdown", "id": "08352b12", "metadata": {}, "source": [ "### Setup\n", "\n", "Here we import all necessary libraries."
] }, { "cell_type": "code", "execution_count": 1, "id": "c1047c94", "metadata": {}, "outputs": [], "source": [ "from typing import List, Tuple, Optional, Dict\n", "import string\n", "\n", "import mlx.core as mx\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from scipy.spatial.distance import squareform, pdist\n", "import biotite.structure as bs\n", "from biotite.database import rcsb\n", "from biotite.structure.io.pdbx import CIFFile, get_structure\n", "from Bio import SeqIO" ] }, { "cell_type": "markdown", "id": "5f0af076", "metadata": {}, "source": [ "Download multiple sequence alignment (MSA) files for our three test proteins from the ESM repository." ] }, { "cell_type": "code", "execution_count": 2, "id": "3264b66d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " % Total % Received % Xferd Average Speed Time Time Time Current\n", " Dload Upload Total Spent Left Speed\n", "100 147k 100 147k 0 0 536k 0 --:--:-- --:--:-- --:--:-- 538k\n", " % Total % Received % Xferd Average Speed Time Time Time Current\n", " Dload Upload Total Spent Left Speed\n", "100 127k 100 127k 0 0 485k 0 --:--:-- --:--:-- --:--:-- 486k\n", " % Total % Received % Xferd Average Speed Time Time Time Current\n", " Dload Upload Total Spent Left Speed\n", "100 181k 100 181k 0 0 738k 0 --:--:-- --:--:-- --:--:-- 740k\n" ] } ], "source": [ "!mkdir -p data\n", "!curl -o data/1a3a_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/1a3a_1_A.a3m\n", "!curl -o data/5ahw_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/5ahw_1_A.a3m\n", "!curl -o data/1xcr_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/1xcr_1_A.a3m" ] }, { "cell_type": "markdown", "id": "cbf1d0cb", "metadata": {}, "source": [ "### Loading the model\n", "\n", "Load the ESM-2 model. Change the path below to point to your converted checkpoint." 
] }, { "cell_type": "code", "execution_count": 3, "id": "4406e8a0", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append(\"..\")\n", "\n", "from esm import ESM2\n", "\n", "esm_checkpoint = \"../checkpoints/mlx-esm2_t33_650M_UR50D\"\n", "tokenizer, model = ESM2.from_pretrained(esm_checkpoint)" ] }, { "cell_type": "markdown", "id": "77596456", "metadata": {}, "source": [ "### Defining functions" ] }, { "cell_type": "markdown", "id": "eb5f07ed", "metadata": {}, "source": [ "#### Parsing alignments" ] }, { "cell_type": "markdown", "id": "e754abd7", "metadata": {}, "source": [ "These functions parse multiple sequence alignment files and clean up insertion artifacts. MSA files often contain lowercase letters and special characters (`.`, `*`) to indicate insertions relative to the reference sequence; these need to be removed to get the core aligned sequences." ] }, { "cell_type": "code", "execution_count": 4, "id": "43717bea", "metadata": {}, "outputs": [], "source": [ "deletekeys = dict.fromkeys(string.ascii_lowercase)\n", "deletekeys[\".\"] = None\n", "deletekeys[\"*\"] = None\n", "translation = str.maketrans(deletekeys)\n", "\n", "def read_sequence(filename: str) -> Tuple[str, str]:\n", "    \"\"\" Reads the first (reference) sequence from a fasta or MSA file.\"\"\"\n", "    record = next(SeqIO.parse(filename, \"fasta\"))\n", "    return record.description, str(record.seq)\n", "\n", "def remove_insertions(sequence: str) -> str:\n", "    \"\"\" Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. \"\"\"\n", "    return sequence.translate(translation)\n", "\n", "def read_msa(filename: str) -> List[Tuple[str, str]]:\n", "    \"\"\" Reads the sequences from an MSA file, automatically removes insertions.\"\"\"\n", "    return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, \"fasta\")]" ] }, { "cell_type": "markdown", "id": "628d7de1", "metadata": {}, "source": [ "#### Converting structures to contacts\n", "\n", "There are many ways to define a protein contact. Here, two residues are in contact when their beta carbon (Cβ) atoms are within 8 angstroms of each other. Note that the Cβ position is imputed from the positions of the N, CA, and C atoms of each residue." ] }, { "cell_type": "code", "execution_count": 5, "id": "21e0b44b", "metadata": {}, "outputs": [], "source": [ "def extend(a, b, c, L, A, D):\n", "    \"\"\"\n", "    input: 3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral\n", "    output: 4th coord\n", "    \"\"\"\n", "    def normalize(x):\n", "        return x / np.linalg.norm(x, ord=2, axis=-1, keepdims=True)\n", "\n", "    bc = normalize(b - c)\n", "    n = normalize(np.cross(b - a, bc))\n", "    m = [bc, np.cross(n, bc), n]\n", "    d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]\n", "    return c + sum([m * d for m, d in zip(m, d)])\n", "\n", "def contacts_from_pdb(\n", "    structure: bs.AtomArray,\n", "    distance_threshold: float = 8.0,\n", "    chain: Optional[str] = None,\n", ") -> np.ndarray:\n", "    \"\"\"Extract contacts from PDB structure.\"\"\"\n", "    mask = ~structure.hetero\n", "    if chain is not None:\n", "        mask &= structure.chain_id == chain\n", "\n", "    N = structure.coord[mask & (structure.atom_name == \"N\")]\n", "    CA = structure.coord[mask & (structure.atom_name == \"CA\")]\n", "    C = structure.coord[mask & (structure.atom_name == \"C\")]\n", "\n", "    Cbeta = extend(C, N, CA, 1.522, 1.927, -2.143)\n", "    dist = squareform(pdist(Cbeta))\n", "    \n", "    contacts = dist < distance_threshold\n", "    contacts = contacts.astype(np.int64)\n", "    contacts[np.isnan(dist)] = -1\n", "    return contacts" ] }, { "cell_type": "markdown", "id": "5473f306", "metadata": {}, "source": [ "#### Computing contact precisions" ] }, { "cell_type": "markdown", "id": "e361a9f3", "metadata": {}, "source": [ "Calculate precision metrics to evaluate contact prediction quality. The `compute_precisions` function ranks predicted contacts by confidence and measures how many of the top predictions are true contacts, while `evaluate_prediction` breaks this down by sequence separation ranges (local, short, medium, long-range) since predicting distant contacts is typically much harder than nearby ones." ] }, { "cell_type": "code", "execution_count": 6, "id": "62c37bbd", "metadata": {}, "outputs": [], "source": [ "def compute_precisions(\n", "    predictions: mx.array,\n", "    targets: mx.array,\n", "    minsep: int = 6,\n", "    maxsep: Optional[int] = None,\n", "    override_length: Optional[int] = None,\n", ") -> Dict[str, mx.array]:\n", "    \"\"\"Compute precision metrics for contact prediction.\"\"\"\n", "    batch_size, seqlen, _ = predictions.shape\n", "    \n", "    if maxsep is not None:\n", "        sep_mask_2d = mx.abs(mx.arange(seqlen)[None, :] - mx.arange(seqlen)[:, None]) <= maxsep\n", "        targets = targets * sep_mask_2d[None, :]\n", "    \n", "    targets = targets.astype(mx.float32)\n", "    # Use the sequence length L here. Counting every valid entry of the L x L\n", "    # target matrix gives roughly L*L, which pushes all precision cutoffs past\n", "    # the top-L list and collapses AUC, P@L, P@L2, and P@L5 to one value.\n", "    src_lengths = mx.full((batch_size,), seqlen, dtype=mx.float32)\n", "    \n", "    x_ind, y_ind = [], []\n", "    for i in range(seqlen):\n", "        for j in range(i + minsep, seqlen):\n", "            x_ind.append(i)\n", "            y_ind.append(j)\n", "    \n", "    x_ind = mx.array(x_ind)\n", "    y_ind = mx.array(y_ind)\n", "    \n", "    predictions_upper = predictions[:, x_ind, y_ind]\n", "    targets_upper = targets[:, x_ind, y_ind]\n", "\n", "    topk = seqlen if override_length is None else max(seqlen, override_length)\n", "    indices = mx.argsort(predictions_upper, axis=-1)[:, ::-1][:, :topk]\n", "    \n", "    batch_indices = mx.arange(batch_size)[:, None]\n", "    topk_targets = targets_upper[batch_indices, indices]\n", "    \n", "    if topk_targets.shape[1] < topk:\n", "        pad_shape = (topk_targets.shape[0], topk - topk_targets.shape[1])\n", "        padding = mx.zeros(pad_shape)\n", "        topk_targets = mx.concatenate([topk_targets, padding], 1)\n", "\n", "    cumulative_dist = mx.cumsum(topk_targets, -1)\n", "\n", "    gather_lengths = src_lengths[:, None]\n", "    if override_length is not None:\n", "        gather_lengths = override_length * mx.ones_like(gather_lengths)\n", "\n", "    precision_fractions = mx.arange(0.1, 1.1, 0.1)\n", "    gather_indices = (precision_fractions[None, :] * gather_lengths) - 1\n", "    gather_indices = mx.clip(gather_indices, 0, cumulative_dist.shape[1] - 1)\n", "    gather_indices = gather_indices.astype(mx.int32)\n", "\n", "    binned_cumulative_dist = cumulative_dist[batch_indices, gather_indices]\n", "    binned_precisions = binned_cumulative_dist / (gather_indices + 1)\n", "\n", "    pl5 = binned_precisions[:, 1]\n", "    pl2 = binned_precisions[:, 4]\n", "    pl = binned_precisions[:, 9]\n", "    auc = binned_precisions.mean(-1)\n", "\n", "    return {\"AUC\": auc, \"P@L\": pl, \"P@L2\": pl2, \"P@L5\": pl5}\n", "\n", "def evaluate_prediction(\n", "    predictions: mx.array,\n", "    targets: mx.array,\n", ") -> Dict[str, float]:\n", "    \"\"\"Evaluate contact predictions across different sequence separation ranges.\"\"\"\n", "    contact_ranges = [\n", "        (\"local\", 3, 6),\n", "        (\"short\", 6, 12),\n", "        (\"medium\", 12, 24),\n", "        (\"long\", 24, None),\n", "    ]\n", "    metrics = {}\n", "    \n", "    for name, minsep, maxsep in contact_ranges:\n", "        rangemetrics = compute_precisions(\n", "            predictions,\n", "            targets,\n", "            minsep=minsep,\n", "            maxsep=maxsep,\n", "        )\n", "        for key, val in rangemetrics.items():\n", "            metrics[f\"{name}_{key}\"] = float(val[0])\n", "    return metrics" ] }, { "cell_type": "markdown", "id": "5873e052", "metadata": {}, "source": [ "#### Predicting contacts" ] }, { "cell_type": "markdown", "id": "2d5778a9", "metadata": {}, "source": [ "This function wraps the tokenization and model inference steps, converting a raw amino acid sequence into token IDs and passing them through ESM-2's contact prediction head to produce a contact probability matrix." ] }, { "cell_type": "code", "execution_count": 7, "id": "dddf31a7", "metadata": {}, "outputs": [], "source": [ "def predict_contacts(sequence: str, model, tokenizer) -> mx.array:\n", "    tokens = tokenizer.encode(sequence)\n", "    tokens = mx.array([tokens])\n", "    contacts = model.predict_contacts(tokens)\n", "    return contacts" ] }, { "cell_type": "markdown", "id": "62562401", "metadata": {}, "source": [ "#### Plotting results\n", "\n", "This function visualizes contacts as a symmetric matrix where both axes index residue positions. The lower triangle shows the model’s confidence as a blue heatmap, with darker cells indicating higher confidence. The upper triangle overlays evaluation markers: blue dots are correctly predicted contacts (true positives), red dots are predicted but not real (false positives), and grey dots are real contacts the model missed (false negatives)."
] }, { "cell_type": "code", "execution_count": 8, "id": "03e03791", "metadata": {}, "outputs": [], "source": [ "def plot_contacts_and_predictions(\n", "    predictions: mx.array,\n", "    contacts: np.ndarray,\n", "    ax,\n", "    title: str,\n", "    cmap: str = \"Blues\",\n", "    ms: float = 1,\n", "):\n", "    \"\"\"Plot contact predictions and true contacts.\"\"\"\n", "    if isinstance(predictions, mx.array):\n", "        predictions = np.array(predictions)\n", "    \n", "    seqlen = contacts.shape[0]\n", "    relative_distance = np.add.outer(-np.arange(seqlen), np.arange(seqlen))\n", "    bottom_mask = relative_distance < 0\n", "    masked_image = np.ma.masked_where(bottom_mask, predictions)\n", "    invalid_mask = np.abs(np.add.outer(np.arange(seqlen), -np.arange(seqlen))) < 6\n", "    predictions_copy = predictions.copy()\n", "    predictions_copy[invalid_mask] = float(\"-inf\")\n", "\n", "    topl_val = np.sort(predictions_copy.reshape(-1))[-seqlen]\n", "    pred_contacts = predictions_copy >= topl_val\n", "    true_positives = contacts & pred_contacts & ~bottom_mask\n", "    false_positives = ~contacts & pred_contacts & ~bottom_mask\n", "    other_contacts = contacts & ~pred_contacts & ~bottom_mask\n", "\n", "    ax.imshow(masked_image, cmap=cmap)\n", "    ax.plot(*np.where(other_contacts), \"o\", c=\"grey\", ms=ms)\n", "    ax.plot(*np.where(false_positives), \"o\", c=\"r\", ms=ms)\n", "    ax.plot(*np.where(true_positives), \"o\", c=\"b\", ms=ms)\n", "    ax.set_title(title)\n", "    ax.axis(\"square\")\n", "    ax.set_xlim([0, seqlen])\n", "    ax.set_ylim([0, seqlen])" ] }, { "cell_type": "markdown", "id": "9364c984", "metadata": {}, "source": [ "### Predict and visualize\n", "Here we'll use ESM-2 contact prediction on our three test proteins and evaluate the results. We'll compute precision metrics across different sequence separation ranges and create contact maps that visualize both the model's predictions and how well they match the true protein structures."
] }, { "cell_type": "markdown", "id": "9fa9e59e", "metadata": {}, "source": [ "#### Read Data" ] }, { "cell_type": "markdown", "id": "7da50dc2", "metadata": {}, "source": [ "Load experimental protein structures from the Protein Data Bank and extract true contact maps for evaluation, while also parsing the reference sequences from our MSA files that will serve as input to ESM-2." ] }, { "cell_type": "code", "execution_count": 9, "id": "2d276137", "metadata": {}, "outputs": [], "source": [ "PDB_IDS = [\"1a3a\", \"5ahw\", \"1xcr\"]\n", "\n", "structures = {\n", " name.lower(): get_structure(CIFFile.read(rcsb.fetch(name, \"cif\")))[0]\n", " for name in PDB_IDS\n", "}\n", "\n", "contacts = {\n", " name: contacts_from_pdb(structure, chain=\"A\") \n", " for name, structure in structures.items()\n", "}\n", "\n", "msas = {\n", " name: read_msa(f\"data/{name.lower()}_1_A.a3m\")\n", " for name in PDB_IDS\n", "}\n", "\n", "sequences = {\n", " name: msa[0] for name, msa in msas.items()\n", "}" ] }, { "cell_type": "markdown", "id": "4ce64f18", "metadata": {}, "source": [ "#### ESM-2 predictions" ] }, { "cell_type": "markdown", "id": "1f2da88f", "metadata": {}, "source": [ "##### Evaluate predictions" ] }, { "cell_type": "markdown", "id": "0adb0a11", "metadata": {}, "source": [ "This loop generates contact predictions for each protein using ESM-2, compares them against the experimentally determined structures, and computes precision metrics across different sequence separation ranges to evaluate model performance." ] }, { "cell_type": "code", "execution_count": 10, "id": "941b4afa", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|   | id | model | local_AUC | local_P@L | local_P@L2 | local_P@L5 | short_AUC | short_P@L | short_P@L2 | short_P@L5 | medium_AUC | medium_P@L | medium_P@L2 | medium_P@L5 | long_AUC | long_P@L | long_P@L2 | long_P@L5 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1a3a | ESM-2 (Unsupervised) | 0.193103 | 0.193103 | 0.193103 | 0.193103 | 0.172414 | 0.172414 | 0.172414 | 0.172414 | 0.262069 | 0.262069 | 0.262069 | 0.262069 | 0.689655 | 0.689655 | 0.689655 | 0.689655 |
| 1 | 5ahw | ESM-2 (Unsupervised) | 0.024000 | 0.024000 | 0.024000 | 0.024000 | 0.136000 | 0.136000 | 0.136000 | 0.136000 | 0.144000 | 0.144000 | 0.144000 | 0.144000 | 0.864000 | 0.864000 | 0.864000 | 0.864000 |
| 2 | 1xcr | ESM-2 (Unsupervised) | 0.111821 | 0.111821 | 0.111821 | 0.111821 | 0.159744 | 0.159744 | 0.159744 | 0.159744 | 0.175719 | 0.175719 | 0.175719 | 0.175719 | 0.408946 | 0.408946 | 0.408946 | 0.408946 |