Formatted code

Vincent Amato
2025-08-15 23:54:53 -04:00
parent eccabdd227
commit 98d800866e
10 changed files with 200 additions and 127 deletions

View File

@@ -26,7 +26,7 @@ inputs = tokenizer(
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024
max_length=1024,
)
input_ids = inputs["input_ids"].to("mps")
attention_mask = inputs["attention_mask"].to("mps")
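A minimal end-to-end sketch of how a tokenizer call like the one in this hunk is typically used on Apple's MPS backend; the checkpoint name and the example sequences are assumptions for illustration, not taken from this commit.

import torch
from transformers import AutoTokenizer, EsmForMaskedLM

# Assumed checkpoint; any ESM-2 checkpoint on the Hub works the same way.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D").to("mps")

sequences = ["MKTAYIAKQR", "MALWMRLLPLLALLALW"]  # toy protein sequences
inputs = tokenizer(
    sequences,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=1024,
)
input_ids = inputs["input_ids"].to("mps")
attention_mask = inputs["attention_mask"].to("mps")

with torch.no_grad():
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits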

View File

@@ -8,6 +8,7 @@ import mlx.core as mx
import torch
from huggingface_hub import snapshot_download
def download(hf_repo: str) -> Path:
"""Download model from Hugging Face."""
return Path(
@@ -17,6 +18,7 @@ def download(hf_repo: str) -> Path:
)
)
def remap_key(key: str) -> str:
"""Remap HuggingFace ESM key names to MLX format."""
@@ -112,7 +114,9 @@ def convert(model_path: Path) -> Dict[str, mx.array]:
for key, value in weights.items():
mlx_key = remap_key(key)
if mlx_key is not None:
mlx_weights[mlx_key] = mx.array(value) if not isinstance(value, mx.array) else value
mlx_weights[mlx_key] = (
mx.array(value) if not isinstance(value, mx.array) else value
)
# If lm_head.weight is missing but embed_tokens.weight exists, set up weight sharing
# (This is for smaller models that don't have a separate lm_head.decoder.weight)
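The comment above describes a weight-tying fallback for checkpoints without a separate LM head. A hedged sketch of that fallback is below; the key names "lm_head.weight" and "embed_tokens.weight" are assumed from the remapping convention and may not match convert.py exactly.

from typing import Dict
import mlx.core as mx

def tie_lm_head(mlx_weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """If the checkpoint has no separate LM head, reuse the token embedding matrix."""
    if "lm_head.weight" not in mlx_weights and "embed_tokens.weight" in mlx_weights:
        mlx_weights["lm_head.weight"] = mlx_weights["embed_tokens.weight"]
    return mlx_weights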
@@ -125,19 +129,13 @@ def convert(model_path: Path) -> Dict[str, mx.array]:
def main():
parser = argparse.ArgumentParser(description="Convert ESM weights to MLX format")
parser.add_argument(
"--hf-path",
default="facebook/esm2_t6_8M_UR50D",
help="Hugging Face model path"
)
parser.add_argument(
"--mlx-path",
default=None,
help="Output path for MLX model"
"--hf-path", default="facebook/esm2_t6_8M_UR50D", help="Hugging Face model path"
)
parser.add_argument("--mlx-path", default=None, help="Output path for MLX model")
parser.add_argument(
"--checkpoints-dir",
default="checkpoints",
help="Directory to save checkpoints (default: checkpoints)"
help="Directory to save checkpoints (default: checkpoints)",
)
args = parser.parse_args()
@@ -174,5 +172,6 @@ def main():
print(f"Conversion complete! MLX model saved to {mlx_path}")
if __name__ == "__main__":
main()

View File

@@ -2,18 +2,18 @@
ESM-2 protein language model implementation in MLX
"""
from .model import ESM2
from .tokenizer import ProteinTokenizer
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
from .attention import MultiheadAttention
from .model import ESM2
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
from .rotary_embedding import RotaryEmbedding
from .tokenizer import ProteinTokenizer
__all__ = [
'ESM2',
'ProteinTokenizer',
'ContactPredictionHead',
'RobertaLMHead',
'TransformerLayer',
'MultiheadAttention',
'RotaryEmbedding'
"ESM2",
"ProteinTokenizer",
"ContactPredictionHead",
"RobertaLMHead",
"TransformerLayer",
"MultiheadAttention",
"RotaryEmbedding",
]

View File

@@ -5,6 +5,7 @@ import mlx.nn as nn
from .rotary_embedding import RotaryEmbedding
class MultiheadAttention(nn.Module):
"""
Multi-head attention layer with rotary position embeddings, as used in ESM-2.
@@ -114,7 +115,9 @@ class MultiheadAttention(nn.Module):
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len)
# Convert key_padding_mask to boolean and expand dimensions
# key_padding_mask: [bsz, src_len] -> [bsz, 1, 1, src_len]
mask = mx.expand_dims(mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2)
mask = mx.expand_dims(
mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2
)
# Apply mask: set attention to -inf where mask is True (padded positions)
attn_weights = mx.where(mask, -mx.inf, attn_weights)
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
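The expand_dims/where pattern in this hunk turns a [bsz, src_len] padding mask into a [bsz, 1, 1, src_len] broadcastable mask and fills padded key positions with -inf before the softmax. A standalone sketch with illustrative shapes:

import mlx.core as mx

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = mx.zeros((bsz, num_heads, tgt_len, src_len))
key_padding_mask = mx.array(
    [[0, 0, 0, 1, 1],   # 1 marks padded key positions
     [0, 0, 0, 0, 0]]
)

# [bsz, src_len] -> [bsz, 1, 1, src_len] so it broadcasts over heads and query positions
mask = mx.expand_dims(mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2)
attn_weights = mx.where(mask, -mx.inf, attn_weights)  # -inf becomes 0 after softmax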
@@ -135,16 +138,16 @@ class MultiheadAttention(nn.Module):
attn_weights_out: Optional[mx.array] = None
if need_head_weights:
# Return attention weights for each head separately
attn_weights_out = attn_weights_float.reshape(
bsz, self.num_heads, tgt_len, src_len
).astype(attn.dtype).swapaxes(0, 1)
attn_weights_out = (
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len)
.astype(attn.dtype)
.swapaxes(0, 1)
)
else:
# Return averaged attention weights
attn_weights_out = mx.mean(
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len),
axis=1
axis=1,
).astype(attn.dtype)
return attn, attn_weights_out
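A shape-only sketch of the two return paths above (per-head weights versus head-averaged weights), with illustrative sizes:

import mlx.core as mx

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights_float = mx.random.uniform(shape=(bsz * num_heads, tgt_len, src_len))

per_head = attn_weights_float.reshape(bsz, num_heads, tgt_len, src_len).swapaxes(0, 1)
averaged = mx.mean(attn_weights_float.reshape(bsz, num_heads, tgt_len, src_len), axis=1)
print(per_head.shape)   # (num_heads, bsz, tgt_len, src_len)
print(averaged.shape)   # (bsz, tgt_len, src_len)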

View File

@@ -1,12 +1,12 @@
from typing import List, Dict, Optional, Tuple
from pathlib import Path
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .tokenizer import ProteinTokenizer
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
from .tokenizer import ProteinTokenizer
class ESM2(nn.Module):
@@ -128,8 +128,12 @@ class ESM2(nn.Module):
mask_ratio_train = 0.15 * 0.8
src_lengths = mx.sum(~padding_mask, axis=-1, keepdims=True)
mask_ratio_observed = mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
scale_factor = (1 - mask_ratio_train) / mx.maximum(1 - mask_ratio_observed, 1e-8)
mask_ratio_observed = (
mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
)
scale_factor = (1 - mask_ratio_train) / mx.maximum(
1 - mask_ratio_observed, 1e-8
)
x = x * scale_factor[:, None, :]
# Zero out padding positions
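The scale factor in the hunk above compensates for the gap between the training mask ratio (0.15 * 0.8 = 12% of tokens actually replaced by <mask>) and the mask ratio observed in the current batch. A small worked sketch with toy values:

import mlx.core as mx

mask_ratio_train = 0.15 * 0.8                          # 0.12
mask_positions = mx.array([[0, 1, 0, 0, 1, 0, 0, 0]])  # 2 of 8 tokens masked
padding_mask = mx.array([[False] * 8])                 # no padding in this toy batch

src_lengths = mx.sum(~padding_mask, axis=-1, keepdims=True)                          # 8
mask_ratio_observed = mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths   # 0.25
scale_factor = (1 - mask_ratio_train) / mx.maximum(1 - mask_ratio_observed, 1e-8)
print(scale_factor)  # ~1.173: embeddings scale up when more tokens are masked than in training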
@@ -194,7 +198,9 @@ class ESM2(nn.Module):
# Mask out padded positions if present
if padding_mask is not None:
attention_mask = 1 - padding_mask.astype(attentions.dtype)
attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(attention_mask, 2)
attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(
attention_mask, 2
)
attentions = attentions * attention_mask[:, None, None, :, :]
result["attentions"] = attentions

View File

@@ -58,7 +58,9 @@ class RotaryEmbedding(nn.Module):
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_tables(self, x: mx.array, seq_dimension: int = 1) -> Tuple[mx.array, mx.array]:
def _update_cos_sin_tables(
self, x: mx.array, seq_dimension: int = 1
) -> Tuple[mx.array, mx.array]:
"""
Compute and cache cos/sin tables for the given sequence length.
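The signature reformatted above belongs to the cached cos/sin computation. A minimal, generic sketch of that caching pattern follows; it is not the repository's RotaryEmbedding, just the usual recompute-on-length-change logic:

import mlx.core as mx

class TinyRotaryCache:
    def __init__(self, dim: int, base: float = 10000.0):
        self.inv_freq = 1.0 / mx.power(base, mx.arange(0, dim, 2) / dim)
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x: mx.array, seq_dimension: int = 1):
        seq_len = x.shape[seq_dimension]
        if seq_len != self._seq_len_cached:  # recompute only when the length changes
            t = mx.arange(seq_len).astype(self.inv_freq.dtype)
            freqs = mx.expand_dims(t, 1) * self.inv_freq   # (seq_len, dim // 2)
            emb = mx.concatenate([freqs, freqs], axis=-1)  # (seq_len, dim)
            self._seq_len_cached = seq_len
            self._cos_cached = mx.cos(emb)
            self._sin_cached = mx.sin(emb)
        return self._cos_cached, self._sin_cached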

View File

@@ -1,14 +1,38 @@
from typing import List, Sequence, Union, Optional
import json
from pathlib import Path
from typing import List, Optional, Sequence, Union
import mlx.core as mx
# Canonical amino-acid tokens (IUPAC standard + uncommon variants)
PROTEIN_TOKENS = [
"L", "A", "G", "V", "S", "E", "R", "T", "I", "D",
"P", "K", "Q", "N", "F", "Y", "M", "H", "W", "C",
"X", "B", "U", "Z", "O", ".", "-"
"L",
"A",
"G",
"V",
"S",
"E",
"R",
"T",
"I",
"D",
"P",
"K",
"Q",
"N",
"F",
"Y",
"M",
"H",
"W",
"C",
"X",
"B",
"U",
"Z",
"O",
".",
"-",
]
ArrayLike = Union[List[int], mx.array]
@@ -201,7 +225,11 @@ class ProteinTokenizer:
for i in ids_list:
tok = self.id_to_token.get(i, "<unk>")
if skip_special_tokens and tok in {
"<cls>", "<pad>", "<eos>", "<unk>", "<mask>"
"<cls>",
"<pad>",
"<eos>",
"<unk>",
"<mask>",
}:
continue
toks.append(tok)

View File

@@ -1,25 +1,27 @@
import argparse
import mlx.core as mx
from esm import ESM2
def main():
parser = argparse.ArgumentParser(description="ESM-2 MLX Inference")
parser.add_argument(
"--model-path",
default="checkpoints/mlx-esm2_t33_650M_UR50D",
help="Path to MLX model checkpoint"
help="Path to MLX model checkpoint",
)
parser.add_argument(
"--sequence",
default="MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
help="Protein sequence to test (default: human insulin)"
help="Protein sequence to test (default: human insulin)",
)
parser.add_argument(
"--mask-position",
type=int,
default=None,
help="Position to mask (default: middle of sequence)"
help="Position to mask (default: middle of sequence)",
)
args = parser.parse_args()
@@ -28,7 +30,9 @@ def main():
# Determine sequence and position to mask
sequence = args.sequence.upper()
mask_pos = args.mask_position if args.mask_position is not None else len(sequence) // 2
mask_pos = (
args.mask_position if args.mask_position is not None else len(sequence) // 2
)
if mask_pos >= len(sequence):
mask_pos = len(sequence) - 1
original_aa = sequence[mask_pos] # The original amino acid at masked position
@@ -40,15 +44,25 @@ def main():
# Tokenize sequence before and after the mask
before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False)
after = tokenizer.encode(sequence[mask_pos+1:], add_special_tokens=False)
after = tokenizer.encode(sequence[mask_pos + 1 :], add_special_tokens=False)
# Build token sequence with <cls>, <mask>, and <eos>
tokens = mx.array([[tokenizer.cls_id] + before.tolist() + [tokenizer.mask_id] + after.tolist() + [tokenizer.eos_id]])
tokens = mx.array(
[
[tokenizer.cls_id]
+ before.tolist()
+ [tokenizer.mask_id]
+ after.tolist()
+ [tokenizer.eos_id]
]
)
mask_token_pos = 1 + len(before) # Position of <mask> token
# Run model to get logits for each token position
logits = model(tokens)["logits"]
probs = mx.softmax(logits[0, mask_token_pos, :]) # Softmax over vocabulary at mask position
probs = mx.softmax(
logits[0, mask_token_pos, :]
) # Softmax over vocabulary at mask position
# Get top-5 most likely tokens
top_indices = mx.argsort(probs)[-5:][::-1]
@@ -62,5 +76,6 @@ def main():
marker = "" if token == original_aa else " "
print(f"{marker} {i+1}. {token}: {prob:.3f} ({prob*100:.1f}%)")
if __name__ == "__main__":
main()
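A hedged sketch of the masked-prediction flow above used programmatically, reusing only APIs that appear in this diff (ESM2.from_pretrained, encode, cls_id/mask_id/eos_id); the checkpoint path and the toy sequence are assumptions, and per-id decoding is assumed to return the residue string.

import mlx.core as mx
from esm import ESM2

tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")

sequence = "MKTAYIAKQR"       # toy sequence, not from this commit
mask_pos = len(sequence) // 2

before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False)
after = tokenizer.encode(sequence[mask_pos + 1 :], add_special_tokens=False)
tokens = mx.array(
    [
        [tokenizer.cls_id]
        + before.tolist()
        + [tokenizer.mask_id]
        + after.tolist()
        + [tokenizer.eos_id]
    ]
)

logits = model(tokens)["logits"]
probs = mx.softmax(logits[0, 1 + len(before), :])  # distribution at the <mask> slot
top_indices = mx.argsort(probs)[-5:][::-1]         # five most likely residues
for i in top_indices.tolist():
    print(tokenizer.decode([i]), float(probs[i]))  # per-id decode is an assumption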

View File

@@ -1,6 +1,7 @@
import unittest
import numpy as np
from transformers import AutoTokenizer, EsmForMaskedLM, EsmConfig
from transformers import AutoTokenizer, EsmConfig, EsmForMaskedLM
from esm import ESM2
@@ -8,22 +9,23 @@ from esm import ESM2
MLX_PATH = "checkpoints/mlx-esm2_t12_35M_UR50D"
HF_PATH = "facebook/esm2_t12_35M_UR50D"
def load_mlx_model():
"""Load MLX ESM-2 model and tokenizer."""
tokenizer, model = ESM2.from_pretrained(MLX_PATH)
return tokenizer, model
def load_hf_model():
"""Load Hugging Face ESM-2 model and tokenizer with hidden states + attentions."""
tokenizer = AutoTokenizer.from_pretrained(HF_PATH)
config = EsmConfig.from_pretrained(
HF_PATH,
output_hidden_states=True,
output_attentions=True
HF_PATH, output_hidden_states=True, output_attentions=True
)
model = EsmForMaskedLM.from_pretrained(HF_PATH, config=config)
return tokenizer, model
class TestESM2(unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -37,12 +39,16 @@ class TestESM2(unittest.TestCase):
sequences = [
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
]
# Compare batched tokenization (padded sequences)
mlx_batch = self.mlx_tokenizer.batch_encode(sequences)
hf_batch = self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"].cpu().numpy()
hf_batch = (
self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"]
.cpu()
.numpy()
)
self.assertEqual(tuple(mlx_batch.shape), tuple(hf_batch.shape))
self.assertTrue(
np.array_equal(np.array(mlx_batch.tolist(), dtype=hf_batch.dtype), hf_batch)
@@ -51,22 +57,29 @@ class TestESM2(unittest.TestCase):
# Compare single-sequence encode/decode
for sequence in sequences:
mlx_tokens = self.mlx_tokenizer.encode(sequence)
hf_tokens = self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"].cpu().numpy().tolist()[0]
hf_tokens = (
self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"]
.cpu()
.numpy()
.tolist()[0]
)
self.assertTrue(np.array_equal(mlx_tokens, hf_tokens))
self.assertEqual(
self.mlx_tokenizer.decode(mlx_tokens),
self.hf_tokenizer.decode(hf_tokens).replace(" ", "")
self.hf_tokenizer.decode(hf_tokens).replace(" ", ""),
)
self.assertEqual(
self.mlx_tokenizer.decode(mlx_tokens, skip_special_tokens=True),
self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(" ", "")
self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(
" ", ""
),
)
def test_model(self):
"""Verify MLX and HF model outputs match (logits, hidden states, attentions)."""
sequences = [
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
]
for sequence in sequences:
# Tokenize
@@ -77,7 +90,7 @@ class TestESM2(unittest.TestCase):
mlx_outputs = self.mlx_model(
mlx_tokens,
repr_layers=[self.mlx_model.num_layers],
need_head_weights=True
need_head_weights=True,
)
hf_outputs = self.hf_model(input_ids=hf_tokens)
@@ -90,12 +103,19 @@ class TestESM2(unittest.TestCase):
final_layer = self.mlx_model.num_layers
mlx_hidden_states = np.array(mlx_outputs["representations"][final_layer])
hf_hidden_states = hf_outputs["hidden_states"][-1].detach().cpu().numpy()
self.assertTrue(np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4))
self.assertTrue(
np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4)
)
# Compare attentions for final layer
mlx_attentions = np.array(mlx_outputs["attentions"][:, final_layer-1, :, :, :])
mlx_attentions = np.array(
mlx_outputs["attentions"][:, final_layer - 1, :, :, :]
)
hf_attentions = hf_outputs["attentions"][-1].detach().cpu().numpy()
self.assertTrue(np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4))
self.assertTrue(
np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4)
)
if __name__ == "__main__":
unittest.main()
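The tests above check MLX/HF parity with rtol=atol=1e-4 on logits, hidden states, and attentions. Below is a compact sketch of the same kind of logits check, using the paths from this diff; how the MLX tokens are batched for the model call is an assumption and may differ from the tests.

import mlx.core as mx
import numpy as np
import torch
from transformers import AutoTokenizer, EsmForMaskedLM
from esm import ESM2

mlx_tokenizer, mlx_model = ESM2.from_pretrained("checkpoints/mlx-esm2_t12_35M_UR50D")
hf_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
hf_model = EsmForMaskedLM.from_pretrained("facebook/esm2_t12_35M_UR50D")

sequence = "MKTAYIAKQR"  # toy sequence
mlx_tokens = mlx_tokenizer.encode(sequence)
hf_tokens = hf_tokenizer(sequence, return_tensors="pt")["input_ids"]

mlx_logits = np.array(mlx_model(mx.array([mlx_tokens.tolist()]))["logits"])  # batch dim assumed
with torch.no_grad():
    hf_logits = hf_model(input_ids=hf_tokens).logits.cpu().numpy()

assert np.allclose(mlx_logits, hf_logits, rtol=1e-4, atol=1e-4)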