Formatted code

Vincent Amato
2025-08-15 23:54:53 -04:00
parent eccabdd227
commit 98d800866e
10 changed files with 200 additions and 127 deletions

View File

@@ -44,4 +44,4 @@ throughput = batch_size * 1000 / ms_per_step
# Display results
print(f"Time (ms) per step: {ms_per_step:.3f}")
print(f"Throughput: {throughput:.2f} sequences/sec")
print(f"Throughput: {throughput:.2f} sequences/sec")

View File

@@ -22,11 +22,11 @@ steps = 50
# Tokenize input sequence and replicate for the batch
# Replicate the same sequence 'batch_size' times to create a batch
inputs = tokenizer(
[protein_sequence] * batch_size,
return_tensors="pt",
padding=True,
[protein_sequence] * batch_size,
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024
max_length=1024,
)
input_ids = inputs["input_ids"].to("mps")
attention_mask = inputs["attention_mask"].to("mps")
@@ -49,4 +49,4 @@ throughput = batch_size * 1000 / ms_per_step
# Report results
print(f"Time (ms) per step: {ms_per_step:.3f}")
print(f"Throughput: {throughput:.2f} sequences/sec")
print(f"Throughput: {throughput:.2f} sequences/sec")

View File

@@ -8,6 +8,7 @@ import mlx.core as mx
import torch
from huggingface_hub import snapshot_download
def download(hf_repo: str) -> Path:
"""Download model from Hugging Face."""
return Path(
@@ -17,13 +18,14 @@ def download(hf_repo: str) -> Path:
)
)
def remap_key(key: str) -> str:
"""Remap HuggingFace ESM key names to MLX format."""
# Skip position embeddings and position_ids
if "position_embeddings" in key or "position_ids" in key:
return None
# Map lm_head components properly
if key == "lm_head.decoder.weight":
return "lm_head.weight"
@@ -37,14 +39,14 @@ def remap_key(key: str) -> str:
return "lm_head.layer_norm.weight"
if key == "lm_head.layer_norm.bias":
return "lm_head.layer_norm.bias"
# Core remapping patterns
key = key.replace("esm.embeddings.word_embeddings", "embed_tokens")
key = key.replace("esm.encoder.emb_layer_norm_after", "emb_layer_norm_after")
key = key.replace("esm.encoder.layer.", "layer_")
key = key.replace("esm.contact_head", "contact_head")
key = key.replace("lm_head", "lm_head")
# Attention patterns
key = key.replace(".attention.self.", ".self_attn.")
key = key.replace(".attention.output.dense", ".self_attn.out_proj")
@@ -53,12 +55,12 @@ def remap_key(key: str) -> str:
key = key.replace(".key", ".k_proj")
key = key.replace(".value", ".v_proj")
key = key.replace(".rotary_embeddings", ".rot_emb")
# FFN patterns
key = key.replace(".intermediate.dense", ".fc1")
key = key.replace(".output.dense", ".fc2")
key = key.replace(".LayerNorm", ".final_layer_norm")
return key
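To make the remapping concrete, this is the behaviour the patterns above imply for two representative checkpoint keys (illustrative only; exact key names depend on the Hugging Face ESM-2 checkpoint layout):

# Position embeddings are dropped entirely:
remap_key("esm.embeddings.position_embeddings.weight")  # -> None

# Encoder attention weights are renamed layer by layer:
remap_key("esm.encoder.layer.0.attention.self.key.weight")
# -> "layer_0.self_attn.k_proj.weight"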
@@ -70,24 +72,24 @@ def load_weights(model_path: Path) -> Dict:
if safetensors_path.exists():
print("Loading from safetensors...")
return mx.load(str(safetensors_path))
# Check for single bin file
single_bin_path = model_path / "pytorch_model.bin"
if single_bin_path.exists():
print("Loading from pytorch_model.bin...")
state_dict = torch.load(str(single_bin_path), map_location="cpu")
return {k: v.numpy() for k, v in state_dict.items()}
# Check for sharded bin files
index_file = model_path / "pytorch_model.bin.index.json"
if index_file.exists():
print("Loading from sharded bin files...")
with open(index_file) as f:
index = json.load(f)
# Get unique shard files
shard_files = set(index["weight_map"].values())
# Load all shards
state_dict = {}
for shard_file in sorted(shard_files):
@@ -95,9 +97,9 @@ def load_weights(model_path: Path) -> Dict:
shard_path = model_path / shard_file
shard_dict = torch.load(str(shard_path), map_location="cpu")
state_dict.update(shard_dict)
return {k: v.numpy() for k, v in state_dict.items()}
raise ValueError(f"No model weights found in {model_path}")
@@ -106,38 +108,34 @@ def convert(model_path: Path) -> Dict[str, mx.array]:
# Load weights from any format
weights = load_weights(model_path)
# Convert keys and create MLX arrays
mlx_weights = {}
for key, value in weights.items():
mlx_key = remap_key(key)
if mlx_key is not None:
mlx_weights[mlx_key] = mx.array(value) if not isinstance(value, mx.array) else value
mlx_weights[mlx_key] = (
mx.array(value) if not isinstance(value, mx.array) else value
)
# If lm_head.weight is missing but embed_tokens.weight exists, set up weight sharing
# (This is for smaller models that don't have a separate lm_head.decoder.weight)
if "lm_head.weight" not in mlx_weights and "embed_tokens.weight" in mlx_weights:
mlx_weights["lm_head.weight"] = mlx_weights["embed_tokens.weight"]
return mlx_weights
def main():
parser = argparse.ArgumentParser(description="Convert ESM weights to MLX format")
parser.add_argument(
"--hf-path",
default="facebook/esm2_t6_8M_UR50D",
help="Hugging Face model path"
)
parser.add_argument(
"--mlx-path",
default=None,
help="Output path for MLX model"
"--hf-path", default="facebook/esm2_t6_8M_UR50D", help="Hugging Face model path"
)
parser.add_argument("--mlx-path", default=None, help="Output path for MLX model")
parser.add_argument(
"--checkpoints-dir",
default="checkpoints",
help="Directory to save checkpoints (default: checkpoints)"
help="Directory to save checkpoints (default: checkpoints)",
)
args = parser.parse_args()
@@ -145,7 +143,7 @@ def main():
# Download model
print(f"Downloading {args.hf_path}...")
model_path = download(args.hf_path)
# Set output path
if args.mlx_path is None:
model_name = args.hf_path.split("/")[-1]
@@ -154,25 +152,26 @@ def main():
args.mlx_path = checkpoints_dir / f"mlx-{model_name}"
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
# Convert weights
print("Converting weights...")
mlx_weights = convert(model_path)
# Save weights
print(f"Saving MLX weights to {mlx_path}...")
mx.save_safetensors(str(mlx_path / "model.safetensors"), mlx_weights)
# Copy config and other files
print("Copying config...")
shutil.copy(model_path / "config.json", mlx_path / "config.json")
for file_name in ["special_tokens_map.json", "tokenizer.json", "vocab.txt"]:
src_file = model_path / file_name
if src_file.exists():
shutil.copy(src_file, mlx_path / file_name)
print(f"Conversion complete! MLX model saved to {mlx_path}")
if __name__ == "__main__":
main()
main()
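For reference, a typical invocation of this converter, assuming the script above is saved as convert.py (the file name is not visible in this diff), would be python convert.py --hf-path facebook/esm2_t12_35M_UR50D, which writes the converted weights, config, and tokenizer files to checkpoints/mlx-esm2_t12_35M_UR50D by default; --mlx-path and --checkpoints-dir override the output location.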

View File

@@ -2,18 +2,18 @@
ESM-2 protein language model implementation in MLX
"""
from .model import ESM2
from .tokenizer import ProteinTokenizer
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
from .attention import MultiheadAttention
from .model import ESM2
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
from .rotary_embedding import RotaryEmbedding
from .tokenizer import ProteinTokenizer
__all__ = [
'ESM2',
'ProteinTokenizer',
'ContactPredictionHead',
'RobertaLMHead',
'TransformerLayer',
'MultiheadAttention',
'RotaryEmbedding'
"ESM2",
"ProteinTokenizer",
"ContactPredictionHead",
"RobertaLMHead",
"TransformerLayer",
"MultiheadAttention",
"RotaryEmbedding",
]
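These exports let user code import directly from the package root, e.g. from esm import ESM2 (as the inference script later in this commit does) or from esm import ProteinTokenizer for the tokenizer alone.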

View File

@@ -5,14 +5,15 @@ import mlx.nn as nn
from .rotary_embedding import RotaryEmbedding
class MultiheadAttention(nn.Module):
"""
Multi-head attention layer with rotary position embeddings, as used in ESM-2.
This module implements both self-attention (when `key` and `value` are not
provided) and cross-attention. It projects input sequences into queries,
keys, and values, applies rotary position embeddings to encode relative
position information, computes scaled dot-product attention over multiple
This module implements both self-attention (when `key` and `value` are not
provided) and cross-attention. It projects input sequences into queries,
keys, and values, applies rotary position embeddings to encode relative
position information, computes scaled dot-product attention over multiple
heads in parallel, and returns a combined output projection.
Args:
@@ -56,7 +57,7 @@ class MultiheadAttention(nn.Module):
) -> Tuple[mx.array, Optional[mx.array]]:
"""
Multi-head attention forward pass.
Args:
query: Tensor of shape (tgt_len, batch, embed_dim).
key: Optional tensor of shape (src_len, batch, embed_dim). Defaults to `query`.
@@ -67,32 +68,32 @@ class MultiheadAttention(nn.Module):
Returns:
attn_output: Tensor of shape (tgt_len, batch, embed_dim).
attn_weights_out: Attention weights of shape
attn_weights_out: Attention weights of shape
(num_heads, batch, tgt_len, src_len) if per-head,
or (batch, tgt_len, src_len) if averaged.
"""
tgt_len, bsz, embed_dim = query.shape
assert embed_dim == self.embed_dim
# For self-attention, use query as key and value if not provided
if key is None:
key = query
if value is None:
value = query
# Project queries, keys, values
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q = q * self.scaling
# Reshape for multi-head attention
q = q.reshape(tgt_len, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
k = k.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
v = v.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
src_len = k.shape[1]
# Apply rotary embeddings if present
@@ -114,7 +115,9 @@ class MultiheadAttention(nn.Module):
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len)
# Convert key_padding_mask to boolean and expand dimensions
# key_padding_mask: [bsz, src_len] -> [bsz, 1, 1, src_len]
mask = mx.expand_dims(mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2)
mask = mx.expand_dims(
mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2
)
# Apply mask: set attention to -inf where mask is True (padded positions)
attn_weights = mx.where(mask, -mx.inf, attn_weights)
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
@@ -126,25 +129,25 @@ class MultiheadAttention(nn.Module):
# Compute attention output
attn = attn_probs @ v
assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
# Reshape output
attn = attn.swapaxes(0, 1).reshape(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
# Return attention weights if requested
attn_weights_out: Optional[mx.array] = None
if need_head_weights:
# Return attention weights for each head separately
attn_weights_out = attn_weights_float.reshape(
bsz, self.num_heads, tgt_len, src_len
).astype(attn.dtype).swapaxes(0, 1)
attn_weights_out = (
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len)
.astype(attn.dtype)
.swapaxes(0, 1)
)
else:
# Return averaged attention weights
attn_weights_out = mx.mean(
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len),
axis=1
axis=1,
).astype(attn.dtype)
return attn, attn_weights_out
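As a reading aid, here is a shape walk-through of the forward pass above with illustrative sizes roughly matching the smallest ESM-2 checkpoint (embed_dim=320, num_heads=20, head_dim=16); the numbers are for orientation only and are not taken from this diff:

# query:  (tgt_len=10, bsz=2, embed_dim=320)
# q = self.q_proj(query) * scaling              -> (10, 2, 320)
# q.reshape(10, 2 * 20, 16).swapaxes(0, 1)      -> (40, 10, 16)    # (bsz*heads, tgt_len, head_dim)
# attention scores of q against k               -> (40, 10, src_len)
# attn = attn_probs @ v                         -> (40, 10, 16)
# attn.swapaxes(0, 1).reshape(10, 2, 320)       -> (10, 2, 320)    # (tgt_len, bsz, embed_dim)
# attn_weights_out, need_head_weights=True      -> (20, 2, 10, 10) # (num_heads, bsz, tgt_len, src_len)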

View File

@@ -1,12 +1,12 @@
from typing import List, Dict, Optional, Tuple
from pathlib import Path
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .tokenizer import ProteinTokenizer
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
from .tokenizer import ProteinTokenizer
class ESM2(nn.Module):
@@ -40,7 +40,7 @@ class ESM2(nn.Module):
self.tokenizer = tokenizer
self.vocab_size = len(tokenizer)
# Special token IDs / config
# Special token IDs / config
self.padding_idx = tokenizer.pad_id
self.mask_idx = tokenizer.mask_id
self.cls_idx = tokenizer.cls_id
@@ -128,8 +128,12 @@ class ESM2(nn.Module):
mask_ratio_train = 0.15 * 0.8
src_lengths = mx.sum(~padding_mask, axis=-1, keepdims=True)
mask_ratio_observed = mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
scale_factor = (1 - mask_ratio_train) / mx.maximum(1 - mask_ratio_observed, 1e-8)
mask_ratio_observed = (
mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
)
scale_factor = (1 - mask_ratio_train) / mx.maximum(
1 - mask_ratio_observed, 1e-8
)
x = x * scale_factor[:, None, :]
# Zero out padding positions
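For reference, the rescaling a few lines above is the usual token-dropout correction from the reference ESM-2 implementation, and a quick worked number shows its size: mask_ratio_train = 0.15 * 0.8 = 0.12, so if roughly 10% of a sequence's positions are masked in the input, that sequence's embeddings are scaled by (1 - 0.12) / (1 - 0.10) ≈ 0.978.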
@@ -194,7 +198,9 @@ class ESM2(nn.Module):
# Mask out padded positions if present
if padding_mask is not None:
attention_mask = 1 - padding_mask.astype(attentions.dtype)
attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(attention_mask, 2)
attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(
attention_mask, 2
)
attentions = attentions * attention_mask[:, None, None, :, :]
result["attentions"] = attentions
@@ -304,7 +310,7 @@ class ESM2(nn.Module):
# Check for vocab and special tokens files
vocab_path = model_dir / "vocab.txt"
special_tokens_path = model_dir / "special_tokens_map.json"
if vocab_path.exists() and special_tokens_path.exists():
tokenizer = ProteinTokenizer(
vocab_file=str(vocab_path),
@@ -333,4 +339,4 @@ class ESM2(nn.Module):
cur[parts[-1]] = value
model.update(nested_weights)
return tokenizer, model
return tokenizer, model

View File

@@ -58,7 +58,9 @@ class RotaryEmbedding(nn.Module):
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_tables(self, x: mx.array, seq_dimension: int = 1) -> Tuple[mx.array, mx.array]:
def _update_cos_sin_tables(
self, x: mx.array, seq_dimension: int = 1
) -> Tuple[mx.array, mx.array]:
"""
Compute and cache cos/sin tables for the given sequence length.
@@ -109,4 +111,4 @@ class RotaryEmbedding(nn.Module):
return (
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
)
)

View File

@@ -1,14 +1,38 @@
from typing import List, Sequence, Union, Optional
import json
from pathlib import Path
from typing import List, Optional, Sequence, Union
import mlx.core as mx
# Canonical amino-acid tokens (IUPAC standard + uncommon variants)
PROTEIN_TOKENS = [
"L", "A", "G", "V", "S", "E", "R", "T", "I", "D",
"P", "K", "Q", "N", "F", "Y", "M", "H", "W", "C",
"X", "B", "U", "Z", "O", ".", "-"
"L",
"A",
"G",
"V",
"S",
"E",
"R",
"T",
"I",
"D",
"P",
"K",
"Q",
"N",
"F",
"Y",
"M",
"H",
"W",
"C",
"X",
"B",
"U",
"Z",
"O",
".",
"-",
]
ArrayLike = Union[List[int], mx.array]
@@ -39,7 +63,7 @@ class ProteinTokenizer:
If both files are provided, they override the default vocabulary and
special token mappings. Otherwise, defaults are loaded.
"""
# Load vocabulary from files if given, otherwise use built-in defaults
if vocab_file and special_tokens_map_file:
self._load_from_files(vocab_file, special_tokens_map_file)
@@ -51,15 +75,15 @@ class ProteinTokenizer:
self.id_to_token = {i: tok for i, tok in enumerate(self.vocab)}
# Cache special token IDs
self.cls_id = self.token_to_id["<cls>"]
self.pad_id = self.token_to_id["<pad>"]
self.eos_id = self.token_to_id["<eos>"]
self.unk_id = self.token_to_id["<unk>"]
self.cls_id = self.token_to_id["<cls>"]
self.pad_id = self.token_to_id["<pad>"]
self.eos_id = self.token_to_id["<eos>"]
self.unk_id = self.token_to_id["<unk>"]
self.mask_id = self.token_to_id["<mask>"]
# Behavior flags for ESM-2-style BOS/EOS
self.prepend_bos = True
self.append_eos = True
self.append_eos = True
def _load_from_files(self, vocab_file: str, special_tokens_map_file: str) -> None:
"""Load vocabulary and special tokens from the provided files."""
@@ -201,7 +225,11 @@ class ProteinTokenizer:
for i in ids_list:
tok = self.id_to_token.get(i, "<unk>")
if skip_special_tokens and tok in {
"<cls>", "<pad>", "<eos>", "<unk>", "<mask>"
"<cls>",
"<pad>",
"<eos>",
"<unk>",
"<mask>",
}:
continue
toks.append(tok)
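Putting the tokenizer together, a small usage sketch based on the behaviour visible in this file (default vocabulary, <cls>/<eos> wrapped around the residues, special tokens skippable on decode); the peptide string is arbitrary and the exact token ids depend on the bundled vocab files:

from esm import ProteinTokenizer

tok = ProteinTokenizer()                           # built-in defaults, no vocab files needed
ids = tok.encode("MKTAYIAK")                       # <cls> + residue tokens + <eos>
print(tok.decode(ids))                             # includes the <cls> and <eos> tokens
print(tok.decode(ids, skip_special_tokens=True))   # -> "MKTAYIAK"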

View File

@@ -1,58 +1,72 @@
import argparse
import mlx.core as mx
from esm import ESM2
def main():
parser = argparse.ArgumentParser(description="ESM-2 MLX Inference")
parser.add_argument(
"--model-path",
"--model-path",
default="checkpoints/mlx-esm2_t33_650M_UR50D",
help="Path to MLX model checkpoint"
help="Path to MLX model checkpoint",
)
parser.add_argument(
"--sequence",
"--sequence",
default="MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
help="Protein sequence to test (default: human insulin)"
help="Protein sequence to test (default: human insulin)",
)
parser.add_argument(
"--mask-position",
type=int,
"--mask-position",
type=int,
default=None,
help="Position to mask (default: middle of sequence)"
help="Position to mask (default: middle of sequence)",
)
args = parser.parse_args()
# Load pretrained ESM-2 model and tokenizer
tokenizer, model = ESM2.from_pretrained(args.model_path)
# Determine sequence and position to mask
sequence = args.sequence.upper()
mask_pos = args.mask_position if args.mask_position is not None else len(sequence) // 2
mask_pos = (
args.mask_position if args.mask_position is not None else len(sequence) // 2
)
if mask_pos >= len(sequence):
mask_pos = len(sequence) - 1
original_aa = sequence[mask_pos] # The original amino acid at masked position
# Display input info
print(f"Original sequence: {sequence}")
print(f"Masked sequence: {sequence[:mask_pos]}<mask>{sequence[mask_pos+1:]}")
print(f"Predicting position {mask_pos}: {original_aa}\n")
# Tokenize sequence before and after the mask
before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False)
after = tokenizer.encode(sequence[mask_pos+1:], add_special_tokens=False)
after = tokenizer.encode(sequence[mask_pos + 1 :], add_special_tokens=False)
# Build token sequence with <cls>, <mask>, and <eos>
tokens = mx.array([[tokenizer.cls_id] + before.tolist() + [tokenizer.mask_id] + after.tolist() + [tokenizer.eos_id]])
tokens = mx.array(
[
[tokenizer.cls_id]
+ before.tolist()
+ [tokenizer.mask_id]
+ after.tolist()
+ [tokenizer.eos_id]
]
)
mask_token_pos = 1 + len(before) # Position of <mask> token
# Run model to get logits for each token position
logits = model(tokens)["logits"]
probs = mx.softmax(logits[0, mask_token_pos, :]) # Softmax over vocabulary at mask position
probs = mx.softmax(
logits[0, mask_token_pos, :]
) # Softmax over vocabulary at mask position
# Get top-5 most likely tokens
top_indices = mx.argsort(probs)[-5:][::-1]
# Print predictions
print("Top predictions:")
for i, idx in enumerate(top_indices):
@@ -62,5 +76,6 @@ def main():
marker = "" if token == original_aa else " "
print(f"{marker} {i+1}. {token}: {prob:.3f} ({prob*100:.1f}%)")
if __name__ == "__main__":
main()
main()
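Assuming the script above is saved as main.py (the file name is not shown in this diff), a run such as python main.py --mask-position 20 loads the 650M MLX checkpoint from checkpoints/mlx-esm2_t33_650M_UR50D, masks position 20 of the default insulin sequence, and prints the top-5 predicted residues with their probabilities; --model-path and --sequence switch the checkpoint and the input protein.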

View File

@@ -1,6 +1,7 @@
import unittest
import numpy as np
from transformers import AutoTokenizer, EsmForMaskedLM, EsmConfig
from transformers import AutoTokenizer, EsmConfig, EsmForMaskedLM
from esm import ESM2
@@ -8,65 +9,77 @@ from esm import ESM2
MLX_PATH = "checkpoints/mlx-esm2_t12_35M_UR50D"
HF_PATH = "facebook/esm2_t12_35M_UR50D"
def load_mlx_model():
"""Load MLX ESM-2 model and tokenizer."""
tokenizer, model = ESM2.from_pretrained(MLX_PATH)
return tokenizer, model
def load_hf_model():
"""Load Hugging Face ESM-2 model and tokenizer with hidden states + attentions."""
tokenizer = AutoTokenizer.from_pretrained(HF_PATH)
config = EsmConfig.from_pretrained(
HF_PATH,
output_hidden_states=True,
output_attentions=True
HF_PATH, output_hidden_states=True, output_attentions=True
)
model = EsmForMaskedLM.from_pretrained(HF_PATH, config=config)
return tokenizer, model
class TestESM2(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Load both MLX and HF models/tokenizers once for all tests
cls.mlx_tokenizer, cls.mlx_model = load_mlx_model()
cls.hf_tokenizer, cls.hf_model = load_hf_model()
def test_tokenizer(self):
"""Verify MLX tokenizer matches Hugging Face tokenizer behavior."""
self.assertEqual(len(self.mlx_tokenizer), len(self.hf_tokenizer))
sequences = [
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
]
# Compare batched tokenization (padded sequences)
mlx_batch = self.mlx_tokenizer.batch_encode(sequences)
hf_batch = self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"].cpu().numpy()
hf_batch = (
self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"]
.cpu()
.numpy()
)
self.assertEqual(tuple(mlx_batch.shape), tuple(hf_batch.shape))
self.assertTrue(
np.array_equal(np.array(mlx_batch.tolist(), dtype=hf_batch.dtype), hf_batch)
)
# Compare single-sequence encode/decode
for sequence in sequences:
mlx_tokens = self.mlx_tokenizer.encode(sequence)
hf_tokens = self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"].cpu().numpy().tolist()[0]
hf_tokens = (
self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"]
.cpu()
.numpy()
.tolist()[0]
)
self.assertTrue(np.array_equal(mlx_tokens, hf_tokens))
self.assertEqual(
self.mlx_tokenizer.decode(mlx_tokens),
self.hf_tokenizer.decode(hf_tokens).replace(" ", "")
self.hf_tokenizer.decode(hf_tokens).replace(" ", ""),
)
self.assertEqual(
self.mlx_tokenizer.decode(mlx_tokens, skip_special_tokens=True),
self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(" ", "")
self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(
" ", ""
),
)
def test_model(self):
"""Verify MLX and HF model outputs match (logits, hidden states, attentions)."""
sequences = [
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
]
for sequence in sequences:
# Tokenize
@@ -75,9 +88,9 @@ class TestESM2(unittest.TestCase):
# Forward pass
mlx_outputs = self.mlx_model(
mlx_tokens,
repr_layers=[self.mlx_model.num_layers],
need_head_weights=True
mlx_tokens,
repr_layers=[self.mlx_model.num_layers],
need_head_weights=True,
)
hf_outputs = self.hf_model(input_ids=hf_tokens)
@@ -90,12 +103,19 @@ class TestESM2(unittest.TestCase):
final_layer = self.mlx_model.num_layers
mlx_hidden_states = np.array(mlx_outputs["representations"][final_layer])
hf_hidden_states = hf_outputs["hidden_states"][-1].detach().cpu().numpy()
self.assertTrue(np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4))
self.assertTrue(
np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4)
)
# Compare attentions for final layer
mlx_attentions = np.array(mlx_outputs["attentions"][:, final_layer-1, :, :, :])
mlx_attentions = np.array(
mlx_outputs["attentions"][:, final_layer - 1, :, :, :]
)
hf_attentions = hf_outputs["attentions"][-1].detach().cpu().numpy()
self.assertTrue(np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4))
self.assertTrue(
np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4)
)
if __name__ == "__main__":
unittest.main()
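Note that these tests compare against facebook/esm2_t12_35M_UR50D and expect the converted checkpoint to already exist at checkpoints/mlx-esm2_t12_35M_UR50D, so the convert script has to be run for that model first; assuming the file is named test_esm2.py, python -m unittest test_esm2.py would then run the comparison.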