mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Formatted code
This commit is contained in:
@@ -26,7 +26,7 @@ inputs = tokenizer(
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=1024
|
||||
max_length=1024,
|
||||
)
|
||||
input_ids = inputs["input_ids"].to("mps")
|
||||
attention_mask = inputs["attention_mask"].to("mps")
|
||||
|
||||
@@ -8,6 +8,7 @@ import mlx.core as mx
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
def download(hf_repo: str) -> Path:
|
||||
"""Download model from Hugging Face."""
|
||||
return Path(
|
||||
@@ -17,6 +18,7 @@ def download(hf_repo: str) -> Path:
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def remap_key(key: str) -> str:
|
||||
"""Remap HuggingFace ESM key names to MLX format."""
|
||||
|
||||
@@ -112,7 +114,9 @@ def convert(model_path: Path) -> Dict[str, mx.array]:
|
||||
for key, value in weights.items():
|
||||
mlx_key = remap_key(key)
|
||||
if mlx_key is not None:
|
||||
mlx_weights[mlx_key] = mx.array(value) if not isinstance(value, mx.array) else value
|
||||
mlx_weights[mlx_key] = (
|
||||
mx.array(value) if not isinstance(value, mx.array) else value
|
||||
)
|
||||
|
||||
# If lm_head.weight is missing but embed_tokens.weight exists, set up weight sharing
|
||||
# (This is for smaller models that don't have a separate lm_head.decoder.weight)
|
||||
@@ -125,19 +129,13 @@ def convert(model_path: Path) -> Dict[str, mx.array]:
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Convert ESM weights to MLX format")
|
||||
parser.add_argument(
|
||||
"--hf-path",
|
||||
default="facebook/esm2_t6_8M_UR50D",
|
||||
help="Hugging Face model path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx-path",
|
||||
default=None,
|
||||
help="Output path for MLX model"
|
||||
"--hf-path", default="facebook/esm2_t6_8M_UR50D", help="Hugging Face model path"
|
||||
)
|
||||
parser.add_argument("--mlx-path", default=None, help="Output path for MLX model")
|
||||
parser.add_argument(
|
||||
"--checkpoints-dir",
|
||||
default="checkpoints",
|
||||
help="Directory to save checkpoints (default: checkpoints)"
|
||||
help="Directory to save checkpoints (default: checkpoints)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -174,5 +172,6 @@ def main():
|
||||
|
||||
print(f"Conversion complete! MLX model saved to {mlx_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -2,18 +2,18 @@
|
||||
ESM-2 protein language model implementation in MLX
|
||||
"""
|
||||
|
||||
from .model import ESM2
|
||||
from .tokenizer import ProteinTokenizer
|
||||
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
|
||||
from .attention import MultiheadAttention
|
||||
from .model import ESM2
|
||||
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
|
||||
from .rotary_embedding import RotaryEmbedding
|
||||
from .tokenizer import ProteinTokenizer
|
||||
|
||||
__all__ = [
|
||||
'ESM2',
|
||||
'ProteinTokenizer',
|
||||
'ContactPredictionHead',
|
||||
'RobertaLMHead',
|
||||
'TransformerLayer',
|
||||
'MultiheadAttention',
|
||||
'RotaryEmbedding'
|
||||
"ESM2",
|
||||
"ProteinTokenizer",
|
||||
"ContactPredictionHead",
|
||||
"RobertaLMHead",
|
||||
"TransformerLayer",
|
||||
"MultiheadAttention",
|
||||
"RotaryEmbedding",
|
||||
]
|
||||
|
||||
@@ -5,6 +5,7 @@ import mlx.nn as nn
|
||||
|
||||
from .rotary_embedding import RotaryEmbedding
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
"""
|
||||
Multi-head attention layer with rotary position embeddings, as used in ESM-2.
|
||||
@@ -114,7 +115,9 @@ class MultiheadAttention(nn.Module):
|
||||
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len)
|
||||
# Convert key_padding_mask to boolean and expand dimensions
|
||||
# key_padding_mask: [bsz, src_len] -> [bsz, 1, 1, src_len]
|
||||
mask = mx.expand_dims(mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2)
|
||||
mask = mx.expand_dims(
|
||||
mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2
|
||||
)
|
||||
# Apply mask: set attention to -inf where mask is True (padded positions)
|
||||
attn_weights = mx.where(mask, -mx.inf, attn_weights)
|
||||
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
|
||||
@@ -135,16 +138,16 @@ class MultiheadAttention(nn.Module):
|
||||
attn_weights_out: Optional[mx.array] = None
|
||||
if need_head_weights:
|
||||
# Return attention weights for each head separately
|
||||
attn_weights_out = attn_weights_float.reshape(
|
||||
bsz, self.num_heads, tgt_len, src_len
|
||||
).astype(attn.dtype).swapaxes(0, 1)
|
||||
attn_weights_out = (
|
||||
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len)
|
||||
.astype(attn.dtype)
|
||||
.swapaxes(0, 1)
|
||||
)
|
||||
else:
|
||||
# Return averaged attention weights
|
||||
attn_weights_out = mx.mean(
|
||||
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len),
|
||||
axis=1
|
||||
axis=1,
|
||||
).astype(attn.dtype)
|
||||
|
||||
return attn, attn_weights_out
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from pathlib import Path
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .tokenizer import ProteinTokenizer
|
||||
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
|
||||
from .tokenizer import ProteinTokenizer
|
||||
|
||||
|
||||
class ESM2(nn.Module):
|
||||
@@ -128,8 +128,12 @@ class ESM2(nn.Module):
|
||||
|
||||
mask_ratio_train = 0.15 * 0.8
|
||||
src_lengths = mx.sum(~padding_mask, axis=-1, keepdims=True)
|
||||
mask_ratio_observed = mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
|
||||
scale_factor = (1 - mask_ratio_train) / mx.maximum(1 - mask_ratio_observed, 1e-8)
|
||||
mask_ratio_observed = (
|
||||
mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
|
||||
)
|
||||
scale_factor = (1 - mask_ratio_train) / mx.maximum(
|
||||
1 - mask_ratio_observed, 1e-8
|
||||
)
|
||||
x = x * scale_factor[:, None, :]
|
||||
|
||||
# Zero out padding positions
|
||||
@@ -194,7 +198,9 @@ class ESM2(nn.Module):
|
||||
# Mask out padded positions if present
|
||||
if padding_mask is not None:
|
||||
attention_mask = 1 - padding_mask.astype(attentions.dtype)
|
||||
attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(attention_mask, 2)
|
||||
attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(
|
||||
attention_mask, 2
|
||||
)
|
||||
attentions = attentions * attention_mask[:, None, None, :, :]
|
||||
|
||||
result["attentions"] = attentions
|
||||
|
||||
@@ -58,7 +58,9 @@ class RotaryEmbedding(nn.Module):
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
|
||||
def _update_cos_sin_tables(self, x: mx.array, seq_dimension: int = 1) -> Tuple[mx.array, mx.array]:
|
||||
def _update_cos_sin_tables(
|
||||
self, x: mx.array, seq_dimension: int = 1
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""
|
||||
Compute and cache cos/sin tables for the given sequence length.
|
||||
|
||||
|
||||
@@ -1,14 +1,38 @@
|
||||
from typing import List, Sequence, Union, Optional
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
# Canonical amino-acid tokens (IUPAC standard + uncommon variants)
|
||||
PROTEIN_TOKENS = [
|
||||
"L", "A", "G", "V", "S", "E", "R", "T", "I", "D",
|
||||
"P", "K", "Q", "N", "F", "Y", "M", "H", "W", "C",
|
||||
"X", "B", "U", "Z", "O", ".", "-"
|
||||
"L",
|
||||
"A",
|
||||
"G",
|
||||
"V",
|
||||
"S",
|
||||
"E",
|
||||
"R",
|
||||
"T",
|
||||
"I",
|
||||
"D",
|
||||
"P",
|
||||
"K",
|
||||
"Q",
|
||||
"N",
|
||||
"F",
|
||||
"Y",
|
||||
"M",
|
||||
"H",
|
||||
"W",
|
||||
"C",
|
||||
"X",
|
||||
"B",
|
||||
"U",
|
||||
"Z",
|
||||
"O",
|
||||
".",
|
||||
"-",
|
||||
]
|
||||
|
||||
ArrayLike = Union[List[int], mx.array]
|
||||
@@ -201,7 +225,11 @@ class ProteinTokenizer:
|
||||
for i in ids_list:
|
||||
tok = self.id_to_token.get(i, "<unk>")
|
||||
if skip_special_tokens and tok in {
|
||||
"<cls>", "<pad>", "<eos>", "<unk>", "<mask>"
|
||||
"<cls>",
|
||||
"<pad>",
|
||||
"<eos>",
|
||||
"<unk>",
|
||||
"<mask>",
|
||||
}:
|
||||
continue
|
||||
toks.append(tok)
|
||||
|
||||
29
esm/main.py
29
esm/main.py
@@ -1,25 +1,27 @@
|
||||
import argparse
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from esm import ESM2
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ESM-2 MLX Inference")
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
default="checkpoints/mlx-esm2_t33_650M_UR50D",
|
||||
help="Path to MLX model checkpoint"
|
||||
help="Path to MLX model checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sequence",
|
||||
default="MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
|
||||
help="Protein sequence to test (default: human insulin)"
|
||||
help="Protein sequence to test (default: human insulin)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mask-position",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Position to mask (default: middle of sequence)"
|
||||
help="Position to mask (default: middle of sequence)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -28,7 +30,9 @@ def main():
|
||||
|
||||
# Determine sequence and position to mask
|
||||
sequence = args.sequence.upper()
|
||||
mask_pos = args.mask_position if args.mask_position is not None else len(sequence) // 2
|
||||
mask_pos = (
|
||||
args.mask_position if args.mask_position is not None else len(sequence) // 2
|
||||
)
|
||||
if mask_pos >= len(sequence):
|
||||
mask_pos = len(sequence) - 1
|
||||
original_aa = sequence[mask_pos] # The original amino acid at masked position
|
||||
@@ -40,15 +44,25 @@ def main():
|
||||
|
||||
# Tokenize sequence before and after the mask
|
||||
before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False)
|
||||
after = tokenizer.encode(sequence[mask_pos+1:], add_special_tokens=False)
|
||||
after = tokenizer.encode(sequence[mask_pos + 1 :], add_special_tokens=False)
|
||||
|
||||
# Build token sequence with <cls>, <mask>, and <eos>
|
||||
tokens = mx.array([[tokenizer.cls_id] + before.tolist() + [tokenizer.mask_id] + after.tolist() + [tokenizer.eos_id]])
|
||||
tokens = mx.array(
|
||||
[
|
||||
[tokenizer.cls_id]
|
||||
+ before.tolist()
|
||||
+ [tokenizer.mask_id]
|
||||
+ after.tolist()
|
||||
+ [tokenizer.eos_id]
|
||||
]
|
||||
)
|
||||
mask_token_pos = 1 + len(before) # Position of <mask> token
|
||||
|
||||
# Run model to get logits for each token position
|
||||
logits = model(tokens)["logits"]
|
||||
probs = mx.softmax(logits[0, mask_token_pos, :]) # Softmax over vocabulary at mask position
|
||||
probs = mx.softmax(
|
||||
logits[0, mask_token_pos, :]
|
||||
) # Softmax over vocabulary at mask position
|
||||
|
||||
# Get top-5 most likely tokens
|
||||
top_indices = mx.argsort(probs)[-5:][::-1]
|
||||
@@ -62,5 +76,6 @@ def main():
|
||||
marker = "✓" if token == original_aa else " "
|
||||
print(f"{marker} {i+1}. {token}: {prob:.3f} ({prob*100:.1f}%)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
48
esm/test.py
48
esm/test.py
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer, EsmForMaskedLM, EsmConfig
|
||||
from transformers import AutoTokenizer, EsmConfig, EsmForMaskedLM
|
||||
|
||||
from esm import ESM2
|
||||
|
||||
@@ -8,22 +9,23 @@ from esm import ESM2
|
||||
MLX_PATH = "checkpoints/mlx-esm2_t12_35M_UR50D"
|
||||
HF_PATH = "facebook/esm2_t12_35M_UR50D"
|
||||
|
||||
|
||||
def load_mlx_model():
|
||||
"""Load MLX ESM-2 model and tokenizer."""
|
||||
tokenizer, model = ESM2.from_pretrained(MLX_PATH)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
def load_hf_model():
|
||||
"""Load Hugging Face ESM-2 model and tokenizer with hidden states + attentions."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(HF_PATH)
|
||||
config = EsmConfig.from_pretrained(
|
||||
HF_PATH,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True
|
||||
HF_PATH, output_hidden_states=True, output_attentions=True
|
||||
)
|
||||
model = EsmForMaskedLM.from_pretrained(HF_PATH, config=config)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
class TestESM2(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -37,12 +39,16 @@ class TestESM2(unittest.TestCase):
|
||||
|
||||
sequences = [
|
||||
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
|
||||
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
|
||||
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
|
||||
]
|
||||
|
||||
# Compare batched tokenization (padded sequences)
|
||||
mlx_batch = self.mlx_tokenizer.batch_encode(sequences)
|
||||
hf_batch = self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"].cpu().numpy()
|
||||
hf_batch = (
|
||||
self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"]
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
self.assertEqual(tuple(mlx_batch.shape), tuple(hf_batch.shape))
|
||||
self.assertTrue(
|
||||
np.array_equal(np.array(mlx_batch.tolist(), dtype=hf_batch.dtype), hf_batch)
|
||||
@@ -51,22 +57,29 @@ class TestESM2(unittest.TestCase):
|
||||
# Compare single-sequence encode/decode
|
||||
for sequence in sequences:
|
||||
mlx_tokens = self.mlx_tokenizer.encode(sequence)
|
||||
hf_tokens = self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"].cpu().numpy().tolist()[0]
|
||||
hf_tokens = (
|
||||
self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"]
|
||||
.cpu()
|
||||
.numpy()
|
||||
.tolist()[0]
|
||||
)
|
||||
self.assertTrue(np.array_equal(mlx_tokens, hf_tokens))
|
||||
self.assertEqual(
|
||||
self.mlx_tokenizer.decode(mlx_tokens),
|
||||
self.hf_tokenizer.decode(hf_tokens).replace(" ", "")
|
||||
self.hf_tokenizer.decode(hf_tokens).replace(" ", ""),
|
||||
)
|
||||
self.assertEqual(
|
||||
self.mlx_tokenizer.decode(mlx_tokens, skip_special_tokens=True),
|
||||
self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(" ", "")
|
||||
self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(
|
||||
" ", ""
|
||||
),
|
||||
)
|
||||
|
||||
def test_model(self):
|
||||
"""Verify MLX and HF model outputs match (logits, hidden states, attentions)."""
|
||||
sequences = [
|
||||
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
|
||||
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
|
||||
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
|
||||
]
|
||||
for sequence in sequences:
|
||||
# Tokenize
|
||||
@@ -77,7 +90,7 @@ class TestESM2(unittest.TestCase):
|
||||
mlx_outputs = self.mlx_model(
|
||||
mlx_tokens,
|
||||
repr_layers=[self.mlx_model.num_layers],
|
||||
need_head_weights=True
|
||||
need_head_weights=True,
|
||||
)
|
||||
hf_outputs = self.hf_model(input_ids=hf_tokens)
|
||||
|
||||
@@ -90,12 +103,19 @@ class TestESM2(unittest.TestCase):
|
||||
final_layer = self.mlx_model.num_layers
|
||||
mlx_hidden_states = np.array(mlx_outputs["representations"][final_layer])
|
||||
hf_hidden_states = hf_outputs["hidden_states"][-1].detach().cpu().numpy()
|
||||
self.assertTrue(np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(
|
||||
np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4)
|
||||
)
|
||||
|
||||
# Compare attentions for final layer
|
||||
mlx_attentions = np.array(mlx_outputs["attentions"][:, final_layer-1, :, :, :])
|
||||
mlx_attentions = np.array(
|
||||
mlx_outputs["attentions"][:, final_layer - 1, :, :, :]
|
||||
)
|
||||
hf_attentions = hf_outputs["attentions"][-1].detach().cpu().numpy()
|
||||
self.assertTrue(np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(
|
||||
np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user