mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00

Formatted code
@@ -26,7 +26,7 @@ inputs = tokenizer(
     return_tensors="pt",
     padding=True,
     truncation=True,
-    max_length=1024
+    max_length=1024,
 )
 input_ids = inputs["input_ids"].to("mps")
 attention_mask = inputs["attention_mask"].to("mps")
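Side note on the `.to("mps")` lines above: the MLX side of this comparison needs no device transfer, since MLX arrays live in unified memory. A minimal sketch under assumptions (the checkpoint path and the `from_pretrained`/`batch_encode` signatures are taken from esm/test.py further down; the sequences are toy strings):

import mlx.core as mx

from esm import ESM2

# Sketch only: load the MLX model the way esm/test.py does.
tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t12_35M_UR50D")
tokens = tokenizer.batch_encode(["MKTAYIAKQR", "MALWMRLLPL"])  # padded token ids as an mx.array
logits = model(tokens)["logits"]  # no .to(device) step needed on the MLX side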
@@ -8,6 +8,7 @@ import mlx.core as mx
 import torch
 from huggingface_hub import snapshot_download
 
+
 def download(hf_repo: str) -> Path:
     """Download model from Hugging Face."""
     return Path(
@@ -17,6 +18,7 @@ def download(hf_repo: str) -> Path:
         )
     )
 
+
 def remap_key(key: str) -> str:
     """Remap HuggingFace ESM key names to MLX format."""
 
@@ -112,7 +114,9 @@ def convert(model_path: Path) -> Dict[str, mx.array]:
     for key, value in weights.items():
         mlx_key = remap_key(key)
         if mlx_key is not None:
-            mlx_weights[mlx_key] = mx.array(value) if not isinstance(value, mx.array) else value
+            mlx_weights[mlx_key] = (
+                mx.array(value) if not isinstance(value, mx.array) else value
+            )
 
     # If lm_head.weight is missing but embed_tokens.weight exists, set up weight sharing
     # (This is for smaller models that don't have a separate lm_head.decoder.weight)
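For context on the weight-sharing comment: a minimal illustrative sketch of that fallback, using assumed key names (`lm_head.weight`, `embed_tokens.weight`) rather than the repo's exact remapped names:

# Illustrative sketch only: tie the LM head to the token embedding when no
# separate lm_head weight was exported (assumed key names, not the repo's exact ones).
def tie_lm_head(mlx_weights):
    if "lm_head.weight" not in mlx_weights and "embed_tokens.weight" in mlx_weights:
        # Weight sharing: the output projection reuses the input embedding matrix.
        mlx_weights["lm_head.weight"] = mlx_weights["embed_tokens.weight"]
    return mlx_weights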
@@ -125,19 +129,13 @@ def convert(model_path: Path) -> Dict[str, mx.array]:
 def main():
     parser = argparse.ArgumentParser(description="Convert ESM weights to MLX format")
     parser.add_argument(
-        "--hf-path",
-        default="facebook/esm2_t6_8M_UR50D",
-        help="Hugging Face model path"
-    )
-    parser.add_argument(
-        "--mlx-path",
-        default=None,
-        help="Output path for MLX model"
+        "--hf-path", default="facebook/esm2_t6_8M_UR50D", help="Hugging Face model path"
     )
+    parser.add_argument("--mlx-path", default=None, help="Output path for MLX model")
     parser.add_argument(
         "--checkpoints-dir",
         default="checkpoints",
-        help="Directory to save checkpoints (default: checkpoints)"
+        help="Directory to save checkpoints (default: checkpoints)",
     )
 
     args = parser.parse_args()
@@ -174,5 +172,6 @@ def main():
 
     print(f"Conversion complete! MLX model saved to {mlx_path}")
 
+
 if __name__ == "__main__":
     main()
@@ -2,18 +2,18 @@
 ESM-2 protein language model implementation in MLX
 """
 
-from .model import ESM2
-from .tokenizer import ProteinTokenizer
-from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
 from .attention import MultiheadAttention
+from .model import ESM2
+from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
 from .rotary_embedding import RotaryEmbedding
+from .tokenizer import ProteinTokenizer
 
 __all__ = [
-    'ESM2',
-    'ProteinTokenizer',
-    'ContactPredictionHead',
-    'RobertaLMHead',
-    'TransformerLayer',
-    'MultiheadAttention',
-    'RotaryEmbedding'
+    "ESM2",
+    "ProteinTokenizer",
+    "ContactPredictionHead",
+    "RobertaLMHead",
+    "TransformerLayer",
+    "MultiheadAttention",
+    "RotaryEmbedding",
 ]
@@ -5,6 +5,7 @@ import mlx.nn as nn
 
 from .rotary_embedding import RotaryEmbedding
 
+
 class MultiheadAttention(nn.Module):
     """
     Multi-head attention layer with rotary position embeddings, as used in ESM-2.
@@ -114,7 +115,9 @@ class MultiheadAttention(nn.Module):
             attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len)
             # Convert key_padding_mask to boolean and expand dimensions
             # key_padding_mask: [bsz, src_len] -> [bsz, 1, 1, src_len]
-            mask = mx.expand_dims(mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2)
+            mask = mx.expand_dims(
+                mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2
+            )
             # Apply mask: set attention to -inf where mask is True (padded positions)
             attn_weights = mx.where(mask, -mx.inf, attn_weights)
             attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
@@ -135,16 +138,16 @@ class MultiheadAttention(nn.Module):
         attn_weights_out: Optional[mx.array] = None
         if need_head_weights:
             # Return attention weights for each head separately
-            attn_weights_out = attn_weights_float.reshape(
-                bsz, self.num_heads, tgt_len, src_len
-            ).astype(attn.dtype).swapaxes(0, 1)
+            attn_weights_out = (
+                attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len)
+                .astype(attn.dtype)
+                .swapaxes(0, 1)
+            )
         else:
             # Return averaged attention weights
             attn_weights_out = mx.mean(
                 attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len),
-                axis=1
+                axis=1,
             ).astype(attn.dtype)
 
         return attn, attn_weights_out
-
-
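The two hunks above only re-wrap the key-padding-mask logic; here is a small standalone sketch of the same masking idea, with toy shapes and names that are not the module's internals:

import mlx.core as mx

# Toy shapes: batch of 2, one head, 3 query and 3 key positions.
attn_weights = mx.zeros((2, 1, 3, 3))
key_padding_mask = mx.array([[0, 0, 1], [0, 1, 1]])  # 1 marks padded key positions

# [bsz, src_len] -> [bsz, 1, 1, src_len], then broadcast over heads and queries.
mask = mx.expand_dims(mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2)
attn_weights = mx.where(mask, -mx.inf, attn_weights)  # padded keys get -inf
probs = mx.softmax(attn_weights, axis=-1)             # softmax then gives them zero attention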
@@ -1,12 +1,12 @@
-from typing import List, Dict, Optional, Tuple
-from pathlib import Path
 import json
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
 
-from .tokenizer import ProteinTokenizer
 from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
+from .tokenizer import ProteinTokenizer
 
 
 class ESM2(nn.Module):
@@ -128,8 +128,12 @@ class ESM2(nn.Module):
 
             mask_ratio_train = 0.15 * 0.8
             src_lengths = mx.sum(~padding_mask, axis=-1, keepdims=True)
-            mask_ratio_observed = mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
-            scale_factor = (1 - mask_ratio_train) / mx.maximum(1 - mask_ratio_observed, 1e-8)
+            mask_ratio_observed = (
+                mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
+            )
+            scale_factor = (1 - mask_ratio_train) / mx.maximum(
+                1 - mask_ratio_observed, 1e-8
+            )
             x = x * scale_factor[:, None, :]
 
             # Zero out padding positions
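The hunk above re-wraps the masked-token rescaling without changing it; a toy sketch of what that rescaling computes, using the same expressions on made-up inputs:

import mlx.core as mx

# Illustrative sketch of the rescaling in the hunk above (toy inputs).
mask_ratio_train = 0.15 * 0.8  # expected fraction of masked tokens during training
padding_mask = mx.array([[False, False, False, True]])
mask_positions = mx.array([[0, 1, 0, 0]])

src_lengths = mx.sum(~padding_mask, axis=-1, keepdims=True)
mask_ratio_observed = mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
scale_factor = (1 - mask_ratio_train) / mx.maximum(1 - mask_ratio_observed, 1e-8)
# The embeddings x would then be rescaled per sequence: x = x * scale_factor[:, None, :]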
@@ -194,7 +198,9 @@ class ESM2(nn.Module):
             # Mask out padded positions if present
             if padding_mask is not None:
                 attention_mask = 1 - padding_mask.astype(attentions.dtype)
-                attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(attention_mask, 2)
+                attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(
+                    attention_mask, 2
+                )
                 attentions = attentions * attention_mask[:, None, None, :, :]
 
             result["attentions"] = attentions
@@ -58,7 +58,9 @@ class RotaryEmbedding(nn.Module):
         self._cos_cached = None
         self._sin_cached = None
 
-    def _update_cos_sin_tables(self, x: mx.array, seq_dimension: int = 1) -> Tuple[mx.array, mx.array]:
+    def _update_cos_sin_tables(
+        self, x: mx.array, seq_dimension: int = 1
+    ) -> Tuple[mx.array, mx.array]:
         """
         Compute and cache cos/sin tables for the given sequence length.
 
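For readers skimming the `_update_cos_sin_tables` signature change: a generic sketch of what a rotary cos/sin cache typically computes (standard rotary-embedding math, not necessarily this module's exact code):

import mlx.core as mx

def cos_sin_tables(seq_len: int, dim: int, base: float = 10000.0):
    # Inverse frequencies for each pair of channels; positions times frequencies
    # give the rotation angles, whose cos/sin are what gets cached.
    inv_freq = 1.0 / mx.power(base, mx.arange(0, dim, 2) / dim)
    angles = mx.arange(seq_len)[:, None] * inv_freq[None, :]
    emb = mx.concatenate([angles, angles], axis=-1)
    return mx.cos(emb), mx.sin(emb)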
@@ -1,14 +1,38 @@
-from typing import List, Sequence, Union, Optional
 import json
 from pathlib import Path
+from typing import List, Optional, Sequence, Union
 
 import mlx.core as mx
 
 # Canonical amino-acid tokens (IUPAC standard + uncommon variants)
 PROTEIN_TOKENS = [
-    "L", "A", "G", "V", "S", "E", "R", "T", "I", "D",
-    "P", "K", "Q", "N", "F", "Y", "M", "H", "W", "C",
-    "X", "B", "U", "Z", "O", ".", "-"
+    "L",
+    "A",
+    "G",
+    "V",
+    "S",
+    "E",
+    "R",
+    "T",
+    "I",
+    "D",
+    "P",
+    "K",
+    "Q",
+    "N",
+    "F",
+    "Y",
+    "M",
+    "H",
+    "W",
+    "C",
+    "X",
+    "B",
+    "U",
+    "Z",
+    "O",
+    ".",
+    "-",
 ]
 
 ArrayLike = Union[List[int], mx.array]
@@ -201,7 +225,11 @@ class ProteinTokenizer:
         for i in ids_list:
             tok = self.id_to_token.get(i, "<unk>")
             if skip_special_tokens and tok in {
-                "<cls>", "<pad>", "<eos>", "<unk>", "<mask>"
+                "<cls>",
+                "<pad>",
+                "<eos>",
+                "<unk>",
+                "<mask>",
             }:
                 continue
             toks.append(tok)
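A small usage sketch of the decode path touched above, assuming the `ProteinTokenizer` behavior shown elsewhere in this diff (the checkpoint path comes from esm/test.py, and `encode` is assumed to add special tokens by default):

from esm import ESM2

# Sketch only: the tokenizer is loaded the same way esm/test.py loads it.
tokenizer, _ = ESM2.from_pretrained("checkpoints/mlx-esm2_t12_35M_UR50D")

ids = tokenizer.encode("MALWMRLLPL")                     # assumed to include <cls>/<eos>
print(tokenizer.decode(ids))                             # special tokens kept
print(tokenizer.decode(ids, skip_special_tokens=True))   # <cls>, <eos>, etc. dropped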
esm/main.py (29 lines changed)
@@ -1,25 +1,27 @@
 import argparse
+
 import mlx.core as mx
 
 from esm import ESM2
 
+
 def main():
     parser = argparse.ArgumentParser(description="ESM-2 MLX Inference")
     parser.add_argument(
         "--model-path",
         default="checkpoints/mlx-esm2_t33_650M_UR50D",
-        help="Path to MLX model checkpoint"
+        help="Path to MLX model checkpoint",
     )
     parser.add_argument(
         "--sequence",
         default="MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
-        help="Protein sequence to test (default: human insulin)"
+        help="Protein sequence to test (default: human insulin)",
     )
     parser.add_argument(
         "--mask-position",
         type=int,
         default=None,
-        help="Position to mask (default: middle of sequence)"
+        help="Position to mask (default: middle of sequence)",
     )
     args = parser.parse_args()
 
@@ -28,7 +30,9 @@ def main():
 
     # Determine sequence and position to mask
     sequence = args.sequence.upper()
-    mask_pos = args.mask_position if args.mask_position is not None else len(sequence) // 2
+    mask_pos = (
+        args.mask_position if args.mask_position is not None else len(sequence) // 2
+    )
     if mask_pos >= len(sequence):
         mask_pos = len(sequence) - 1
     original_aa = sequence[mask_pos]  # The original amino acid at masked position
@@ -40,15 +44,25 @@ def main():
 
     # Tokenize sequence before and after the mask
     before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False)
-    after = tokenizer.encode(sequence[mask_pos+1:], add_special_tokens=False)
+    after = tokenizer.encode(sequence[mask_pos + 1 :], add_special_tokens=False)
 
     # Build token sequence with <cls>, <mask>, and <eos>
-    tokens = mx.array([[tokenizer.cls_id] + before.tolist() + [tokenizer.mask_id] + after.tolist() + [tokenizer.eos_id]])
+    tokens = mx.array(
+        [
+            [tokenizer.cls_id]
+            + before.tolist()
+            + [tokenizer.mask_id]
+            + after.tolist()
+            + [tokenizer.eos_id]
+        ]
+    )
     mask_token_pos = 1 + len(before)  # Position of <mask> token
 
     # Run model to get logits for each token position
     logits = model(tokens)["logits"]
-    probs = mx.softmax(logits[0, mask_token_pos, :])  # Softmax over vocabulary at mask position
+    probs = mx.softmax(
+        logits[0, mask_token_pos, :]
+    )  # Softmax over vocabulary at mask position
 
     # Get top-5 most likely tokens
     top_indices = mx.argsort(probs)[-5:][::-1]
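The hunk above only re-wraps the prediction code; condensed out of the diff, the flow is: splice in a <mask> token, run the model, softmax the logits at the mask position, and rank the vocabulary. A self-contained sketch using the same APIs (toy sequence; checkpoint path taken from esm/test.py):

import mlx.core as mx

from esm import ESM2

tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t12_35M_UR50D")

sequence = "MALWMRLLPL"          # toy fragment; esm/main.py defaults to human insulin
mask_pos = len(sequence) // 2
before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False)
after = tokenizer.encode(sequence[mask_pos + 1 :], add_special_tokens=False)
tokens = mx.array(
    [[tokenizer.cls_id] + before.tolist() + [tokenizer.mask_id] + after.tolist() + [tokenizer.eos_id]]
)

logits = model(tokens)["logits"]
probs = mx.softmax(logits[0, 1 + len(before), :])  # distribution over the vocabulary at <mask>
top5 = mx.argsort(probs)[-5:][::-1]                # most likely token ids, highest first
print([tokenizer.id_to_token.get(int(i), "<unk>") for i in top5.tolist()])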
@@ -62,5 +76,6 @@ def main():
         marker = "✓" if token == original_aa else " "
         print(f"{marker} {i+1}. {token}: {prob:.3f} ({prob*100:.1f}%)")
 
+
 if __name__ == "__main__":
     main()
esm/test.py (48 lines changed)
@@ -1,6 +1,7 @@
 import unittest
+
 import numpy as np
-from transformers import AutoTokenizer, EsmForMaskedLM, EsmConfig
+from transformers import AutoTokenizer, EsmConfig, EsmForMaskedLM
 
 from esm import ESM2
 
@@ -8,22 +9,23 @@ from esm import ESM2
 MLX_PATH = "checkpoints/mlx-esm2_t12_35M_UR50D"
 HF_PATH = "facebook/esm2_t12_35M_UR50D"
 
+
 def load_mlx_model():
     """Load MLX ESM-2 model and tokenizer."""
     tokenizer, model = ESM2.from_pretrained(MLX_PATH)
     return tokenizer, model
 
+
 def load_hf_model():
     """Load Hugging Face ESM-2 model and tokenizer with hidden states + attentions."""
     tokenizer = AutoTokenizer.from_pretrained(HF_PATH)
     config = EsmConfig.from_pretrained(
-        HF_PATH,
-        output_hidden_states=True,
-        output_attentions=True
+        HF_PATH, output_hidden_states=True, output_attentions=True
     )
     model = EsmForMaskedLM.from_pretrained(HF_PATH, config=config)
     return tokenizer, model
 
+
 class TestESM2(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -37,12 +39,16 @@ class TestESM2(unittest.TestCase):
 
         sequences = [
             "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
-            "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
+            "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
         ]
 
         # Compare batched tokenization (padded sequences)
         mlx_batch = self.mlx_tokenizer.batch_encode(sequences)
-        hf_batch = self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"].cpu().numpy()
+        hf_batch = (
+            self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"]
+            .cpu()
+            .numpy()
+        )
         self.assertEqual(tuple(mlx_batch.shape), tuple(hf_batch.shape))
         self.assertTrue(
             np.array_equal(np.array(mlx_batch.tolist(), dtype=hf_batch.dtype), hf_batch)
@@ -51,22 +57,29 @@ class TestESM2(unittest.TestCase):
         # Compare single-sequence encode/decode
         for sequence in sequences:
             mlx_tokens = self.mlx_tokenizer.encode(sequence)
-            hf_tokens = self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"].cpu().numpy().tolist()[0]
+            hf_tokens = (
+                self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"]
+                .cpu()
+                .numpy()
+                .tolist()[0]
+            )
             self.assertTrue(np.array_equal(mlx_tokens, hf_tokens))
             self.assertEqual(
                 self.mlx_tokenizer.decode(mlx_tokens),
-                self.hf_tokenizer.decode(hf_tokens).replace(" ", "")
+                self.hf_tokenizer.decode(hf_tokens).replace(" ", ""),
             )
             self.assertEqual(
                 self.mlx_tokenizer.decode(mlx_tokens, skip_special_tokens=True),
-                self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(" ", "")
+                self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(
+                    " ", ""
+                ),
             )
 
     def test_model(self):
         """Verify MLX and HF model outputs match (logits, hidden states, attentions)."""
         sequences = [
             "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
-            "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
+            "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
         ]
         for sequence in sequences:
             # Tokenize
@@ -77,7 +90,7 @@ class TestESM2(unittest.TestCase):
             mlx_outputs = self.mlx_model(
                 mlx_tokens,
                 repr_layers=[self.mlx_model.num_layers],
-                need_head_weights=True
+                need_head_weights=True,
             )
             hf_outputs = self.hf_model(input_ids=hf_tokens)
 
@@ -90,12 +103,19 @@ class TestESM2(unittest.TestCase):
             final_layer = self.mlx_model.num_layers
             mlx_hidden_states = np.array(mlx_outputs["representations"][final_layer])
             hf_hidden_states = hf_outputs["hidden_states"][-1].detach().cpu().numpy()
-            self.assertTrue(np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4))
+            self.assertTrue(
+                np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4)
+            )
 
             # Compare attentions for final layer
-            mlx_attentions = np.array(mlx_outputs["attentions"][:, final_layer-1, :, :, :])
+            mlx_attentions = np.array(
+                mlx_outputs["attentions"][:, final_layer - 1, :, :, :]
+            )
             hf_attentions = hf_outputs["attentions"][-1].detach().cpu().numpy()
-            self.assertTrue(np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4))
+            self.assertTrue(
+                np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4)
+            )
 
+
 if __name__ == "__main__":
     unittest.main()