mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-21 20:46:50 +08:00
603 lines
22 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3fbacbe4",
   "metadata": {},
   "source": [
    "## Predicting Protein Contacts with ESM-2\n",
    "\n",
    "Understanding how amino acids interact within a folded protein is essential for grasping protein function and stability. Contact prediction, the task of identifying which residues are close together in three-dimensional space, is a key step in the sequence-to-structure process. ESM-2’s learned attention patterns capture evolutionary signals that encode structural information, which allows the model to predict residue contacts directly from sequence data.\n",
    "\n",
    "In this notebook, we'll explore ESM-2's ability to predict protein contacts across three diverse proteins from different organisms:\n",
    "\n",
    "**Bacterial Transport:**\n",
    "- **1a3a (PTS Mannitol Component)**: A phosphoenolpyruvate-dependent sugar phosphotransferase system component from *E. coli*, essential for carbohydrate metabolism\n",
    "\n",
    "**Stress Response:**\n",
    "- **5ahw (Universal Stress Protein)**: A conserved stress response protein from *Mycolicibacterium smegmatis* that helps cells survive harsh conditions\n",
    "\n",
    "**Human Metabolism:**\n",
    "- **1xcr (Ester Hydrolase)**: A human enzyme (C11orf54) involved in lipid metabolism and cellular signaling pathways\n",
    "\n",
    "We will evaluate how effectively ESM-2 captures the structural relationships present in these sequences, measuring precision across different sequence separation ranges to assess both local and long-range contact prediction performance. This notebook is a modified version of a [notebook by the same name](https://github.com/facebookresearch/esm/blob/main/examples/contact_prediction.ipynb) from the [official ESM repository](https://github.com/facebookresearch/esm)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08352b12",
   "metadata": {},
   "source": [
    "### Setup\n",
    "\n",
    "Here we import all necessary libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1047c94",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Tuple, Optional, Dict\n",
    "import string\n",
    "\n",
    "import mlx.core as mx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.spatial.distance import squareform, pdist\n",
    "import biotite.structure as bs\n",
    "from biotite.database import rcsb\n",
    "from biotite.structure.io.pdbx import CIFFile, get_structure\n",
    "from Bio import SeqIO"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f0af076",
   "metadata": {},
   "source": [
    "Download multiple sequence alignment (MSA) files for our three test proteins from the ESM repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3264b66d",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p data\n",
    "!curl -o data/1a3a_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/1a3a_1_A.a3m\n",
    "!curl -o data/5ahw_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/5ahw_1_A.a3m\n",
    "!curl -o data/1xcr_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/1xcr_1_A.a3m"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbf1d0cb",
   "metadata": {},
   "source": [
    "### Loading the model\n",
    "\n",
    "Load the ESM-2 model. Here we will use the 650M parameter version. Change the path below to point to your converted checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4406e8a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "from esm import ESM2\n",
    "\n",
    "esm_checkpoint = \"../checkpoints/mlx-esm2_t33_650M_UR50D\"\n",
    "tokenizer, model = ESM2.from_pretrained(esm_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77596456",
   "metadata": {},
   "source": [
    "### Defining functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb5f07ed",
   "metadata": {},
   "source": [
    "#### Parsing alignments"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e754abd7",
   "metadata": {},
   "source": [
    "These functions parse multiple sequence alignment files and clean up insertion artifacts. MSA files often contain lowercase letters and special characters (`.`, `*`) that mark insertions relative to the reference sequence; these must be removed to recover the core aligned sequences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43717bea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map lowercase letters, \".\", and \"*\" to None so that str.translate deletes them.\n",
    "deletekeys = dict.fromkeys(string.ascii_lowercase)\n",
    "deletekeys[\".\"] = None\n",
    "deletekeys[\"*\"] = None\n",
    "translation = str.maketrans(deletekeys)\n",
    "\n",
    "def read_sequence(filename: str) -> Tuple[str, str]:\n",
    "    \"\"\"Reads the first (reference) sequence from a FASTA or MSA file.\"\"\"\n",
    "    record = next(SeqIO.parse(filename, \"fasta\"))\n",
    "    return record.description, str(record.seq)\n",
    "\n",
    "def remove_insertions(sequence: str) -> str:\n",
    "    \"\"\"Removes any insertions from the sequence. Needed to load aligned sequences in an MSA.\"\"\"\n",
    "    return sequence.translate(translation)\n",
    "\n",
    "def read_msa(filename: str) -> List[Tuple[str, str]]:\n",
    "    \"\"\"Reads the sequences from an MSA file and removes insertions from each one.\"\"\"\n",
    "    return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, \"fasta\")]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "628d7de1",
   "metadata": {},
   "source": [
    "#### Converting structures to contacts\n",
    "\n",
    "There are many ways to define a protein contact. Here, two residues are in contact when their beta carbons (Cβ) are less than 8 angstroms apart. Note that the Cβ position is imputed from the positions of the N, CA, and C atoms of each residue."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21e0b44b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extend(a, b, c, L, A, D):\n",
    "    \"\"\"\n",
    "    input:  3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral\n",
    "    output: 4th coord\n",
    "    \"\"\"\n",
    "    def normalize(x):\n",
    "        return x / np.linalg.norm(x, ord=2, axis=-1, keepdims=True)\n",
    "\n",
    "    bc = normalize(b - c)\n",
    "    n = normalize(np.cross(b - a, bc))\n",
    "    m = [bc, np.cross(n, bc), n]\n",
    "    d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]\n",
    "    return c + sum([m * d for m, d in zip(m, d)])\n",
    "\n",
    "def contacts_from_pdb(\n",
    "    structure: bs.AtomArray,\n",
    "    distance_threshold: float = 8.0,\n",
    "    chain: Optional[str] = None,\n",
    ") -> np.ndarray:\n",
    "    \"\"\"Extract a residue-residue contact map from a PDB structure.\"\"\"\n",
    "    # Keep standard residues only (drop hetero atoms such as ligands and water).\n",
    "    mask = ~structure.hetero\n",
    "    if chain is not None:\n",
    "        mask &= structure.chain_id == chain\n",
    "\n",
    "    N = structure.coord[mask & (structure.atom_name == \"N\")]\n",
    "    CA = structure.coord[mask & (structure.atom_name == \"CA\")]\n",
    "    C = structure.coord[mask & (structure.atom_name == \"C\")]\n",
    "\n",
    "    # Impute C-beta from the backbone: length 1.522 angstroms, angle 1.927 rad, dihedral -2.143 rad.\n",
    "    Cbeta = extend(C, N, CA, 1.522, 1.927, -2.143)\n",
    "    dist = squareform(pdist(Cbeta))\n",
    "\n",
    "    # 1 = contact, 0 = no contact, -1 = distance undefined (NaN coordinates).\n",
    "    contacts = dist < distance_threshold\n",
    "    contacts = contacts.astype(np.int64)\n",
    "    contacts[np.isnan(dist)] = -1\n",
    "    return contacts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5473f306",
   "metadata": {},
   "source": [
    "#### Computing contact precisions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e361a9f3",
   "metadata": {},
   "source": [
    "Calculate precision metrics to evaluate contact prediction quality. The `compute_precisions` function ranks predicted contacts by confidence and measures how many of the top predictions are true contacts, while `evaluate_prediction` breaks this down by sequence separation ranges (local, short, medium, long-range) since predicting distant contacts is typically much harder than nearby ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62c37bbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_precisions(\n",
    "    predictions: mx.array,\n",
    "    targets: mx.array,\n",
    "    minsep: int = 6,\n",
    "    maxsep: Optional[int] = None,\n",
    "    override_length: Optional[int] = None,\n",
    ") -> Dict[str, mx.array]:\n",
    "    \"\"\"Compute precision metrics for contact prediction.\"\"\"\n",
    "    batch_size, seqlen, _ = predictions.shape\n",
    "\n",
    "    if maxsep is not None:\n",
    "        # Zero out targets beyond maxsep. Predictions beyond maxsep are not masked,\n",
    "        # so a highly ranked pair outside the range counts as a miss.\n",
    "        sep_mask_2d = mx.abs(mx.arange(seqlen)[None, :] - mx.arange(seqlen)[:, None]) <= maxsep\n",
    "        targets = targets * sep_mask_2d[None, :]\n",
    "\n",
    "    targets = targets.astype(mx.float32)\n",
    "    # Use the sequence length L (not the L x L count of valid pairs) so that the\n",
    "    # fractions below select the top L/10, ..., L predictions.\n",
    "    src_lengths = seqlen * mx.ones((batch_size,), dtype=mx.float32)\n",
    "\n",
    "    # Upper-triangle index pairs (i, j) with j - i >= minsep.\n",
    "    x_ind, y_ind = [], []\n",
    "    for i in range(seqlen):\n",
    "        for j in range(i + minsep, seqlen):\n",
    "            x_ind.append(i)\n",
    "            y_ind.append(j)\n",
    "\n",
    "    x_ind = mx.array(x_ind)\n",
    "    y_ind = mx.array(y_ind)\n",
    "\n",
    "    predictions_upper = predictions[:, x_ind, y_ind]\n",
    "    targets_upper = targets[:, x_ind, y_ind]\n",
    "\n",
    "    # Rank candidate pairs by predicted probability, highest first.\n",
    "    topk = seqlen if override_length is None else max(seqlen, override_length)\n",
    "    indices = mx.argsort(predictions_upper, axis=-1)[:, ::-1][:, :topk]\n",
    "\n",
    "    batch_indices = mx.arange(batch_size)[:, None]\n",
    "    topk_targets = targets_upper[batch_indices, indices]\n",
    "\n",
    "    # Pad with zeros when fewer than topk candidate pairs exist.\n",
    "    if topk_targets.shape[1] < topk:\n",
    "        pad_shape = (topk_targets.shape[0], topk - topk_targets.shape[1])\n",
    "        padding = mx.zeros(pad_shape)\n",
    "        topk_targets = mx.concatenate([topk_targets, padding], 1)\n",
    "\n",
    "    cumulative_dist = mx.cumsum(topk_targets, -1)\n",
    "\n",
    "    gather_lengths = src_lengths[:, None]\n",
    "    if override_length is not None:\n",
    "        gather_lengths = override_length * mx.ones_like(gather_lengths)\n",
    "\n",
    "    precision_fractions = mx.arange(0.1, 1.1, 0.1)\n",
    "    gather_indices = (precision_fractions[None, :] * gather_lengths) - 1\n",
    "    gather_indices = mx.clip(gather_indices, 0, cumulative_dist.shape[1] - 1)\n",
    "    gather_indices = gather_indices.astype(mx.int32)\n",
    "\n",
    "    binned_cumulative_dist = cumulative_dist[batch_indices, gather_indices]\n",
    "    binned_precisions = binned_cumulative_dist / (gather_indices + 1)\n",
    "\n",
    "    pl5 = binned_precisions[:, 1]\n",
    "    pl2 = binned_precisions[:, 4]\n",
    "    pl = binned_precisions[:, 9]\n",
    "    auc = binned_precisions.mean(-1)\n",
    "\n",
    "    return {\"AUC\": auc, \"P@L\": pl, \"P@L2\": pl2, \"P@L5\": pl5}\n",
    "\n",
    "def evaluate_prediction(\n",
    "    predictions: mx.array,\n",
    "    targets: mx.array,\n",
    ") -> Dict[str, float]:\n",
    "    \"\"\"Evaluate contact predictions across different sequence separation ranges.\"\"\"\n",
    "    contact_ranges = [\n",
    "        (\"local\", 3, 6),\n",
    "        (\"short\", 6, 12),\n",
    "        (\"medium\", 12, 24),\n",
    "        (\"long\", 24, None),\n",
    "    ]\n",
    "    metrics = {}\n",
    "\n",
    "    for name, minsep, maxsep in contact_ranges:\n",
    "        rangemetrics = compute_precisions(\n",
    "            predictions,\n",
    "            targets,\n",
    "            minsep=minsep,\n",
    "            maxsep=maxsep,\n",
    "        )\n",
    "        for key, val in rangemetrics.items():\n",
    "            metrics[f\"{name}_{key}\"] = float(val[0])\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5873e052",
   "metadata": {},
   "source": [
    "#### Predicting contacts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d5778a9",
   "metadata": {},
   "source": [
    "This function wraps the tokenization and model inference steps, converting a raw amino acid sequence into token IDs and passing them through ESM-2's contact prediction head to produce a contact probability matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dddf31a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_contacts(sequence: str, model, tokenizer) -> mx.array:\n",
    "    \"\"\"Predict contacts for a given sequence.\"\"\"\n",
    "    tokens = tokenizer.encode(sequence)\n",
    "    contacts = model.predict_contacts(tokens)\n",
    "    return contacts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62562401",
   "metadata": {},
   "source": [
    "#### Plotting results\n",
    "\n",
    "This function visualizes contacts as a symmetric matrix where both axes index residue positions. The lower triangle shows the model’s confidence as a blue heatmap, with darker cells indicating higher confidence. The upper triangle overlays evaluation markers: blue dots are correctly predicted contacts (true positives), red dots are predicted but not real (false positives), and grey dots are real contacts the model missed (false negatives)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03e03791",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_contacts_and_predictions(\n",
    "    predictions: mx.array,\n",
    "    contacts: np.ndarray,\n",
    "    ax,\n",
    "    title: str,\n",
    "    cmap: str = \"Blues\",\n",
    "    ms: float = 1,\n",
    "):\n",
    "    \"\"\"Plot contact predictions and true contacts.\"\"\"\n",
    "    if isinstance(predictions, mx.array):\n",
    "        predictions = np.array(predictions)\n",
    "\n",
    "    seqlen = contacts.shape[0]\n",
    "    relative_distance = np.add.outer(-np.arange(seqlen), np.arange(seqlen))\n",
    "    bottom_mask = relative_distance < 0\n",
    "    masked_image = np.ma.masked_where(bottom_mask, predictions)\n",
    "    invalid_mask = np.abs(np.add.outer(np.arange(seqlen), -np.arange(seqlen))) < 6\n",
    "    predictions_copy = predictions.copy()\n",
    "    predictions_copy[invalid_mask] = float(\"-inf\")\n",
    "\n",
    "    topl_val = np.sort(predictions_copy.reshape(-1))[-seqlen]\n",
    "    pred_contacts = predictions_copy >= topl_val\n",
    "    true_positives = contacts & pred_contacts & ~bottom_mask\n",
    "    false_positives = ~contacts & pred_contacts & ~bottom_mask\n",
    "    other_contacts = contacts & ~pred_contacts & ~bottom_mask\n",
    "\n",
    "    ax.imshow(masked_image, cmap=cmap)\n",
    "    ax.plot(*np.where(other_contacts), \"o\", c=\"grey\", ms=ms)\n",
    "    ax.plot(*np.where(false_positives), \"o\", c=\"r\", ms=ms)\n",
    "    ax.plot(*np.where(true_positives), \"o\", c=\"b\", ms=ms)\n",
    "    ax.set_title(title)\n",
    "    ax.axis(\"square\")\n",
    "    ax.set_xlim([0, seqlen])\n",
    "    ax.set_ylim([0, seqlen])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9364c984",
   "metadata": {},
   "source": [
    "### Predict and visualize\n",
    "\n",
    "Here we'll use ESM-2 contact prediction on our three test proteins and evaluate the results. We'll compute precision metrics across different sequence separation ranges and create contact maps that visualize both the model's predictions and how well they match the true protein structures."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fa9e59e",
   "metadata": {},
   "source": [
    "#### Reading data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7da50dc2",
   "metadata": {},
   "source": [
    "Load experimental protein structures from the Protein Data Bank and extract true contact maps for evaluation, while also parsing the reference sequences from our MSA files that will serve as input to ESM-2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d276137",
   "metadata": {},
   "outputs": [],
   "source": [
    "PDB_IDS = [\"1a3a\", \"5ahw\", \"1xcr\"]\n",
    "\n",
    "structures = {\n",
    "    name.lower(): get_structure(CIFFile.read(rcsb.fetch(name, \"cif\")))[0]\n",
    "    for name in PDB_IDS\n",
    "}\n",
    "\n",
    "contacts = {\n",
    "    name: contacts_from_pdb(structure, chain=\"A\")\n",
    "    for name, structure in structures.items()\n",
    "}\n",
    "\n",
    "msas = {\n",
    "    name: read_msa(f\"data/{name.lower()}_1_A.a3m\")\n",
    "    for name in PDB_IDS\n",
    "}\n",
    "\n",
    "sequences = {\n",
    "    name: msa[0] for name, msa in msas.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ce64f18",
   "metadata": {},
   "source": [
    "#### ESM-2 predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f2da88f",
   "metadata": {},
   "source": [
    "##### Evaluate predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0adb0a11",
   "metadata": {},
   "source": [
    "This loop generates contact predictions for each protein using ESM-2, compares them against the experimentally determined structures, and computes precision metrics across different sequence separation ranges to evaluate model performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "941b4afa",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions = {}\n",
    "results = []\n",
    "\n",
    "for pdb_id in sequences:\n",
    "    _, sequence = sequences[pdb_id]\n",
    "    prediction = predict_contacts(sequence, model, tokenizer)\n",
    "    predictions[pdb_id] = prediction[0]\n",
    "\n",
    "    true_contacts = mx.array(contacts[pdb_id])\n",
    "\n",
    "    min_len = min(prediction.shape[1], true_contacts.shape[0])\n",
    "    pred_trimmed = prediction[:, :min_len, :min_len]\n",
    "    true_trimmed = true_contacts[:min_len, :min_len]\n",
    "    true_trimmed = mx.expand_dims(true_trimmed, axis=0)\n",
    "\n",
    "    metrics = evaluate_prediction(pred_trimmed, true_trimmed)\n",
    "    result = {\"id\": pdb_id, \"model\": \"ESM-2 (Unsupervised)\"}\n",
    "    result.update(metrics)\n",
    "    results.append(result)\n",
    "\n",
    "results_df = pd.DataFrame(results)\n",
    "display(results_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5c7418a",
   "metadata": {},
   "source": [
    "The results demonstrate that ESM-2 excels at predicting long-range contacts, with precision scores ranging from 40.9% to 86.4% for residues more than 24 positions apart. Performance is consistently higher for distant contacts than for local ones. For example, the universal stress protein (5ahw) achieves 86.4% precision for long-range contacts but only 2.4% for local contacts between 3 and 6 residues apart. This trend is observed across all three proteins, with medium-range contacts (12–24 residues apart) and short-range contacts (6–12 residues apart) showing intermediate accuracy. These results suggest that ESM-2 has learned to identify evolutionarily conserved structural motifs that connect distant regions of the sequence, which are often critical for protein fold stability and function."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "487cff51",
   "metadata": {},
   "source": [
    "##### Plot contacts and predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10291191",
   "metadata": {},
   "source": [
    "This cell generates contact map visualizations for all three proteins, presenting ESM-2’s predictions as heatmaps and overlaying the true experimental contacts as colored dots."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "628efc10",
   "metadata": {},
   "outputs": [],
   "source": [
    "proteins = [r['id'] for r in results]\n",
    "fig, axes = plt.subplots(figsize=(6 * len(proteins), 6), ncols=len(proteins))\n",
    "if len(proteins) == 1:\n",
    "    axes = [axes]\n",
    "\n",
    "for ax, pdb_id in zip(axes, proteins):\n",
    "    prediction = predictions[pdb_id]\n",
    "    target = contacts[pdb_id]\n",
    "\n",
    "    result = next(r for r in results if r['id'] == pdb_id)\n",
    "    long_pl = result['long_P@L']\n",
    "\n",
    "    plot_contacts_and_predictions(\n",
    "        prediction, target, ax=ax,\n",
    "        title=f\"{pdb_id}: Long Range P@L: {100 * long_pl:.1f}%\"\n",
    "    )\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99e1edaf",
   "metadata": {},
   "source": [
    "The contact maps highlight ESM-2’s strong ability to detect long-range structural relationships. In each panel, the lower triangle shows model predictions, where darker blue regions indicate high-confidence contacts, and the upper triangle shows the corresponding experimental data. Correct predictions appear as blue dots, forming distinct off-diagonal patterns in 5ahw and 1a3a that capture key global fold interactions. Red dots mark false positives, which are relatively rare, while grey dots represent missed contacts. These missed contacts are notably more frequent in 1xcr, consistent with its lower long-range precision. The dense clusters of blue true positives in 5ahw, compared to the sparser, fragmented patterns in 1xcr, clearly illustrate the variation in predictive performance across proteins."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}