Formatted code

Vincent Amato
2025-08-15 23:54:53 -04:00
parent eccabdd227
commit 98d800866e
10 changed files with 200 additions and 127 deletions
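A note on scope: the hunks below are formatting changes in the style of an automated formatter pass (lines re-wrapped, single quotes switched to double quotes, trailing commas added) plus import sorting. The commit message does not name a tool, but running `black .` followed by `isort .` would produce this kind of diff.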

View File

@@ -44,4 +44,4 @@ throughput = batch_size * 1000 / ms_per_step
# Display results
print(f"Time (ms) per step: {ms_per_step:.3f}")
print(f"Throughput: {throughput:.2f} sequences/sec")

View File

@@ -22,11 +22,11 @@ steps = 50
# Tokenize input sequence and replicate for the batch
# Replicate the same sequence 'batch_size' times to create a batch
inputs = tokenizer(
    [protein_sequence] * batch_size,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=1024,
)
input_ids = inputs["input_ids"].to("mps")
attention_mask = inputs["attention_mask"].to("mps")

@@ -49,4 +49,4 @@ throughput = batch_size * 1000 / ms_per_step
# Report results
print(f"Time (ms) per step: {ms_per_step:.3f}")
print(f"Throughput: {throughput:.2f} sequences/sec")

View File

@@ -8,6 +8,7 @@ import mlx.core as mx
import torch
from huggingface_hub import snapshot_download


def download(hf_repo: str) -> Path:
    """Download model from Hugging Face."""
    return Path(

@@ -17,13 +18,14 @@ def download(hf_repo: str) -> Path:
        )
    )


def remap_key(key: str) -> str:
    """Remap HuggingFace ESM key names to MLX format."""
    # Skip position embeddings and position_ids
    if "position_embeddings" in key or "position_ids" in key:
        return None

    # Map lm_head components properly
    if key == "lm_head.decoder.weight":
        return "lm_head.weight"

@@ -37,14 +39,14 @@ def remap_key(key: str) -> str:
        return "lm_head.layer_norm.weight"
    if key == "lm_head.layer_norm.bias":
        return "lm_head.layer_norm.bias"

    # Core remapping patterns
    key = key.replace("esm.embeddings.word_embeddings", "embed_tokens")
    key = key.replace("esm.encoder.emb_layer_norm_after", "emb_layer_norm_after")
    key = key.replace("esm.encoder.layer.", "layer_")
    key = key.replace("esm.contact_head", "contact_head")
    key = key.replace("lm_head", "lm_head")

    # Attention patterns
    key = key.replace(".attention.self.", ".self_attn.")
    key = key.replace(".attention.output.dense", ".self_attn.out_proj")

@@ -53,12 +55,12 @@ def remap_key(key: str) -> str:
    key = key.replace(".key", ".k_proj")
    key = key.replace(".value", ".v_proj")
    key = key.replace(".rotary_embeddings", ".rot_emb")

    # FFN patterns
    key = key.replace(".intermediate.dense", ".fc1")
    key = key.replace(".output.dense", ".fc2")
    key = key.replace(".LayerNorm", ".final_layer_norm")

    return key

@@ -70,24 +72,24 @@ def load_weights(model_path: Path) -> Dict:
    if safetensors_path.exists():
        print("Loading from safetensors...")
        return mx.load(str(safetensors_path))

    # Check for single bin file
    single_bin_path = model_path / "pytorch_model.bin"
    if single_bin_path.exists():
        print("Loading from pytorch_model.bin...")
        state_dict = torch.load(str(single_bin_path), map_location="cpu")
        return {k: v.numpy() for k, v in state_dict.items()}

    # Check for sharded bin files
    index_file = model_path / "pytorch_model.bin.index.json"
    if index_file.exists():
        print("Loading from sharded bin files...")
        with open(index_file) as f:
            index = json.load(f)

        # Get unique shard files
        shard_files = set(index["weight_map"].values())

        # Load all shards
        state_dict = {}
        for shard_file in sorted(shard_files):

@@ -95,9 +97,9 @@ def load_weights(model_path: Path) -> Dict:
            shard_path = model_path / shard_file
            shard_dict = torch.load(str(shard_path), map_location="cpu")
            state_dict.update(shard_dict)

        return {k: v.numpy() for k, v in state_dict.items()}

    raise ValueError(f"No model weights found in {model_path}")

@@ -106,38 +108,34 @@ def convert(model_path: Path) -> Dict[str, mx.array]:
    # Load weights from any format
    weights = load_weights(model_path)

    # Convert keys and create MLX arrays
    mlx_weights = {}
    for key, value in weights.items():
        mlx_key = remap_key(key)
        if mlx_key is not None:
            mlx_weights[mlx_key] = (
                mx.array(value) if not isinstance(value, mx.array) else value
            )

    # If lm_head.weight is missing but embed_tokens.weight exists, set up weight sharing
    # (This is for smaller models that don't have a separate lm_head.decoder.weight)
    if "lm_head.weight" not in mlx_weights and "embed_tokens.weight" in mlx_weights:
        mlx_weights["lm_head.weight"] = mlx_weights["embed_tokens.weight"]

    return mlx_weights


def main():
    parser = argparse.ArgumentParser(description="Convert ESM weights to MLX format")
    parser.add_argument(
        "--hf-path", default="facebook/esm2_t6_8M_UR50D", help="Hugging Face model path"
    )
    parser.add_argument("--mlx-path", default=None, help="Output path for MLX model")
    parser.add_argument(
        "--checkpoints-dir",
        default="checkpoints",
        help="Directory to save checkpoints (default: checkpoints)",
    )
    args = parser.parse_args()

@@ -145,7 +143,7 @@ def main():
    # Download model
    print(f"Downloading {args.hf_path}...")
    model_path = download(args.hf_path)

    # Set output path
    if args.mlx_path is None:
        model_name = args.hf_path.split("/")[-1]

@@ -154,25 +152,26 @@ def main():
        args.mlx_path = checkpoints_dir / f"mlx-{model_name}"

    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Convert weights
    print("Converting weights...")
    mlx_weights = convert(model_path)

    # Save weights
    print(f"Saving MLX weights to {mlx_path}...")
    mx.save_safetensors(str(mlx_path / "model.safetensors"), mlx_weights)

    # Copy config and other files
    print("Copying config...")
    shutil.copy(model_path / "config.json", mlx_path / "config.json")
    for file_name in ["special_tokens_map.json", "tokenizer.json", "vocab.txt"]:
        src_file = model_path / file_name
        if src_file.exists():
            shutil.copy(src_file, mlx_path / file_name)

    print(f"Conversion complete! MLX model saved to {mlx_path}")


if __name__ == "__main__":
    main()
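For reference, assuming this conversion script is saved as convert.py (file names are not preserved in this view), the checkpoint used by the tests below could be produced with `python convert.py --hf-path facebook/esm2_t12_35M_UR50D`, which writes the MLX weights, config, and tokenizer files to checkpoints/mlx-esm2_t12_35M_UR50D.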

View File

@@ -2,18 +2,18 @@
ESM-2 protein language model implementation in MLX
"""

from .attention import MultiheadAttention
from .model import ESM2
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
from .rotary_embedding import RotaryEmbedding
from .tokenizer import ProteinTokenizer

__all__ = [
    "ESM2",
    "ProteinTokenizer",
    "ContactPredictionHead",
    "RobertaLMHead",
    "TransformerLayer",
    "MultiheadAttention",
    "RotaryEmbedding",
]

View File

@@ -5,14 +5,15 @@ import mlx.nn as nn
from .rotary_embedding import RotaryEmbedding


class MultiheadAttention(nn.Module):
    """
    Multi-head attention layer with rotary position embeddings, as used in ESM-2.

    This module implements both self-attention (when `key` and `value` are not
    provided) and cross-attention. It projects input sequences into queries,
    keys, and values, applies rotary position embeddings to encode relative
    position information, computes scaled dot-product attention over multiple
    heads in parallel, and returns a combined output projection.

    Args:

@@ -56,7 +57,7 @@ class MultiheadAttention(nn.Module):
    ) -> Tuple[mx.array, Optional[mx.array]]:
        """
        Multi-head attention forward pass.

        Args:
            query: Tensor of shape (tgt_len, batch, embed_dim).
            key: Optional tensor of shape (src_len, batch, embed_dim). Defaults to `query`.

@@ -67,32 +68,32 @@ class MultiheadAttention(nn.Module):
        Returns:
            attn_output: Tensor of shape (tgt_len, batch, embed_dim).
            attn_weights_out: Attention weights of shape
                (num_heads, batch, tgt_len, src_len) if per-head,
                or (batch, tgt_len, src_len) if averaged.
        """
        tgt_len, bsz, embed_dim = query.shape
        assert embed_dim == self.embed_dim

        # For self-attention, use query as key and value if not provided
        if key is None:
            key = query
        if value is None:
            value = query

        # Project queries, keys, values
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        q = q * self.scaling

        # Reshape for multi-head attention
        q = q.reshape(tgt_len, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
        k = k.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
        v = v.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)

        src_len = k.shape[1]

        # Apply rotary embeddings if present

@@ -114,7 +115,9 @@ class MultiheadAttention(nn.Module):
            attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len)
            # Convert key_padding_mask to boolean and expand dimensions
            # key_padding_mask: [bsz, src_len] -> [bsz, 1, 1, src_len]
            mask = mx.expand_dims(
                mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2
            )
            # Apply mask: set attention to -inf where mask is True (padded positions)
            attn_weights = mx.where(mask, -mx.inf, attn_weights)
            attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)

@@ -126,25 +129,25 @@ class MultiheadAttention(nn.Module):
        # Compute attention output
        attn = attn_probs @ v
        assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]

        # Reshape output
        attn = attn.swapaxes(0, 1).reshape(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        # Return attention weights if requested
        attn_weights_out: Optional[mx.array] = None
        if need_head_weights:
            # Return attention weights for each head separately
            attn_weights_out = (
                attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len)
                .astype(attn.dtype)
                .swapaxes(0, 1)
            )
        else:
            # Return averaged attention weights
            attn_weights_out = mx.mean(
                attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len),
                axis=1,
            ).astype(attn.dtype)

        return attn, attn_weights_out

View File

@@ -1,12 +1,12 @@
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
from .tokenizer import ProteinTokenizer


class ESM2(nn.Module):

@@ -40,7 +40,7 @@ class ESM2(nn.Module):
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer)

        # Special token IDs / config
        self.padding_idx = tokenizer.pad_id
        self.mask_idx = tokenizer.mask_id
        self.cls_idx = tokenizer.cls_id

@@ -128,8 +128,12 @@ class ESM2(nn.Module):
            mask_ratio_train = 0.15 * 0.8
            src_lengths = mx.sum(~padding_mask, axis=-1, keepdims=True)
            mask_ratio_observed = (
                mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
            )
            scale_factor = (1 - mask_ratio_train) / mx.maximum(
                1 - mask_ratio_observed, 1e-8
            )
            x = x * scale_factor[:, None, :]

        # Zero out padding positions

@@ -194,7 +198,9 @@ class ESM2(nn.Module):
            # Mask out padded positions if present
            if padding_mask is not None:
                attention_mask = 1 - padding_mask.astype(attentions.dtype)
                attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(
                    attention_mask, 2
                )
                attentions = attentions * attention_mask[:, None, None, :, :]
            result["attentions"] = attentions

@@ -304,7 +310,7 @@ class ESM2(nn.Module):
        # Check for vocab and special tokens files
        vocab_path = model_dir / "vocab.txt"
        special_tokens_path = model_dir / "special_tokens_map.json"

        if vocab_path.exists() and special_tokens_path.exists():
            tokenizer = ProteinTokenizer(
                vocab_file=str(vocab_path),

@@ -333,4 +339,4 @@ class ESM2(nn.Module):
            cur[parts[-1]] = value

        model.update(nested_weights)
        return tokenizer, model

View File

@@ -58,7 +58,9 @@ class RotaryEmbedding(nn.Module):
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(
        self, x: mx.array, seq_dimension: int = 1
    ) -> Tuple[mx.array, mx.array]:
        """
        Compute and cache cos/sin tables for the given sequence length.

@@ -109,4 +111,4 @@ class RotaryEmbedding(nn.Module):
        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )

View File

@@ -1,14 +1,38 @@
import json
from pathlib import Path
from typing import List, Optional, Sequence, Union

import mlx.core as mx

# Canonical amino-acid tokens (IUPAC standard + uncommon variants)
PROTEIN_TOKENS = [
    "L",
    "A",
    "G",
    "V",
    "S",
    "E",
    "R",
    "T",
    "I",
    "D",
    "P",
    "K",
    "Q",
    "N",
    "F",
    "Y",
    "M",
    "H",
    "W",
    "C",
    "X",
    "B",
    "U",
    "Z",
    "O",
    ".",
    "-",
]

ArrayLike = Union[List[int], mx.array]

@@ -39,7 +63,7 @@ class ProteinTokenizer:
        If both files are provided, they override the default vocabulary and
        special token mappings. Otherwise, defaults are loaded.
        """
        # Load vocabulary from files if given, otherwise use built-in defaults
        if vocab_file and special_tokens_map_file:
            self._load_from_files(vocab_file, special_tokens_map_file)

@@ -51,15 +75,15 @@ class ProteinTokenizer:
        self.id_to_token = {i: tok for i, tok in enumerate(self.vocab)}

        # Cache special token IDs
        self.cls_id = self.token_to_id["<cls>"]
        self.pad_id = self.token_to_id["<pad>"]
        self.eos_id = self.token_to_id["<eos>"]
        self.unk_id = self.token_to_id["<unk>"]
        self.mask_id = self.token_to_id["<mask>"]

        # Behavior flags for ESM-2-style BOS/EOS
        self.prepend_bos = True
        self.append_eos = True

    def _load_from_files(self, vocab_file: str, special_tokens_map_file: str) -> None:
        """Load vocabulary and special tokens from the provided files."""

@@ -201,7 +225,11 @@ class ProteinTokenizer:
        for i in ids_list:
            tok = self.id_to_token.get(i, "<unk>")
            if skip_special_tokens and tok in {
                "<cls>",
                "<pad>",
                "<eos>",
                "<unk>",
                "<mask>",
            }:
                continue
            toks.append(tok)

View File

@@ -1,58 +1,72 @@
import argparse

import mlx.core as mx

from esm import ESM2


def main():
    parser = argparse.ArgumentParser(description="ESM-2 MLX Inference")
    parser.add_argument(
        "--model-path",
        default="checkpoints/mlx-esm2_t33_650M_UR50D",
        help="Path to MLX model checkpoint",
    )
    parser.add_argument(
        "--sequence",
        default="MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
        help="Protein sequence to test (default: human insulin)",
    )
    parser.add_argument(
        "--mask-position",
        type=int,
        default=None,
        help="Position to mask (default: middle of sequence)",
    )
    args = parser.parse_args()

    # Load pretrained ESM-2 model and tokenizer
    tokenizer, model = ESM2.from_pretrained(args.model_path)

    # Determine sequence and position to mask
    sequence = args.sequence.upper()
    mask_pos = (
        args.mask_position if args.mask_position is not None else len(sequence) // 2
    )
    if mask_pos >= len(sequence):
        mask_pos = len(sequence) - 1
    original_aa = sequence[mask_pos]  # The original amino acid at masked position

    # Display input info
    print(f"Original sequence: {sequence}")
    print(f"Masked sequence: {sequence[:mask_pos]}<mask>{sequence[mask_pos+1:]}")
    print(f"Predicting position {mask_pos}: {original_aa}\n")

    # Tokenize sequence before and after the mask
    before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False)
    after = tokenizer.encode(sequence[mask_pos + 1 :], add_special_tokens=False)

    # Build token sequence with <cls>, <mask>, and <eos>
    tokens = mx.array(
        [
            [tokenizer.cls_id]
            + before.tolist()
            + [tokenizer.mask_id]
            + after.tolist()
            + [tokenizer.eos_id]
        ]
    )
    mask_token_pos = 1 + len(before)  # Position of <mask> token

    # Run model to get logits for each token position
    logits = model(tokens)["logits"]
    probs = mx.softmax(
        logits[0, mask_token_pos, :]
    )  # Softmax over vocabulary at mask position

    # Get top-5 most likely tokens
    top_indices = mx.argsort(probs)[-5:][::-1]

    # Print predictions
    print("Top predictions:")
    for i, idx in enumerate(top_indices):

@@ -62,5 +76,6 @@ def main():
        marker = "" if token == original_aa else " "
        print(f"{marker} {i+1}. {token}: {prob:.3f} ({prob*100:.1f}%)")


if __name__ == "__main__":
    main()
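Assuming the script above is saved as, say, predict.py (again, file names are not preserved in this view), `python predict.py --mask-position 30` would mask position 30 of the default insulin sequence and print the top five predicted residues with their probabilities.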

View File

@@ -1,6 +1,7 @@
import unittest

import numpy as np
from transformers import AutoTokenizer, EsmConfig, EsmForMaskedLM

from esm import ESM2

@@ -8,65 +9,77 @@ from esm import ESM2
MLX_PATH = "checkpoints/mlx-esm2_t12_35M_UR50D"
HF_PATH = "facebook/esm2_t12_35M_UR50D"


def load_mlx_model():
    """Load MLX ESM-2 model and tokenizer."""
    tokenizer, model = ESM2.from_pretrained(MLX_PATH)
    return tokenizer, model


def load_hf_model():
    """Load Hugging Face ESM-2 model and tokenizer with hidden states + attentions."""
    tokenizer = AutoTokenizer.from_pretrained(HF_PATH)
    config = EsmConfig.from_pretrained(
        HF_PATH, output_hidden_states=True, output_attentions=True
    )
    model = EsmForMaskedLM.from_pretrained(HF_PATH, config=config)
    return tokenizer, model


class TestESM2(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Load both MLX and HF models/tokenizers once for all tests
        cls.mlx_tokenizer, cls.mlx_model = load_mlx_model()
        cls.hf_tokenizer, cls.hf_model = load_hf_model()

    def test_tokenizer(self):
        """Verify MLX tokenizer matches Hugging Face tokenizer behavior."""
        self.assertEqual(len(self.mlx_tokenizer), len(self.hf_tokenizer))

        sequences = [
            "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
        ]

        # Compare batched tokenization (padded sequences)
        mlx_batch = self.mlx_tokenizer.batch_encode(sequences)
        hf_batch = (
            self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"]
            .cpu()
            .numpy()
        )
        self.assertEqual(tuple(mlx_batch.shape), tuple(hf_batch.shape))
        self.assertTrue(
            np.array_equal(np.array(mlx_batch.tolist(), dtype=hf_batch.dtype), hf_batch)
        )

        # Compare single-sequence encode/decode
        for sequence in sequences:
            mlx_tokens = self.mlx_tokenizer.encode(sequence)
            hf_tokens = (
                self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"]
                .cpu()
                .numpy()
                .tolist()[0]
            )
            self.assertTrue(np.array_equal(mlx_tokens, hf_tokens))
            self.assertEqual(
                self.mlx_tokenizer.decode(mlx_tokens),
                self.hf_tokenizer.decode(hf_tokens).replace(" ", ""),
            )
            self.assertEqual(
                self.mlx_tokenizer.decode(mlx_tokens, skip_special_tokens=True),
                self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(
                    " ", ""
                ),
            )

    def test_model(self):
        """Verify MLX and HF model outputs match (logits, hidden states, attentions)."""
        sequences = [
            "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
        ]
        for sequence in sequences:
            # Tokenize

@@ -75,9 +88,9 @@ class TestESM2(unittest.TestCase):
            # Forward pass
            mlx_outputs = self.mlx_model(
                mlx_tokens,
                repr_layers=[self.mlx_model.num_layers],
                need_head_weights=True,
            )
            hf_outputs = self.hf_model(input_ids=hf_tokens)

@@ -90,12 +103,19 @@ class TestESM2(unittest.TestCase):
            final_layer = self.mlx_model.num_layers
            mlx_hidden_states = np.array(mlx_outputs["representations"][final_layer])
            hf_hidden_states = hf_outputs["hidden_states"][-1].detach().cpu().numpy()
            self.assertTrue(
                np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4)
            )

            # Compare attentions for final layer
            mlx_attentions = np.array(
                mlx_outputs["attentions"][:, final_layer - 1, :, :, :]
            )
            hf_attentions = hf_outputs["attentions"][-1].detach().cpu().numpy()
            self.assertTrue(
                np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4)
            )


if __name__ == "__main__":
    unittest.main()