Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-12-16 02:08:55 +08:00
Formatted code
@@ -44,4 +44,4 @@ throughput = batch_size * 1000 / ms_per_step

# Display results
print(f"Time (ms) per step: {ms_per_step:.3f}")
-print(f"Throughput: {throughput:.2f} sequences/sec")
+print(f"Throughput: {throughput:.2f} sequences/sec")

@@ -22,11 +22,11 @@ steps = 50
# Tokenize input sequence and replicate for the batch
# Replicate the same sequence 'batch_size' times to create a batch
inputs = tokenizer(
-[protein_sequence] * batch_size,
-return_tensors="pt",
-padding=True,
+[protein_sequence] * batch_size,
+return_tensors="pt",
+padding=True,
truncation=True,
-max_length=1024
+max_length=1024,
)
input_ids = inputs["input_ids"].to("mps")
attention_mask = inputs["attention_mask"].to("mps")

@@ -49,4 +49,4 @@ throughput = batch_size * 1000 / ms_per_step

# Report results
print(f"Time (ms) per step: {ms_per_step:.3f}")
-print(f"Throughput: {throughput:.2f} sequences/sec")
+print(f"Throughput: {throughput:.2f} sequences/sec")
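The hunk headers above appear to come from the benchmark scripts (the file names are not shown in this extract), where throughput is derived as batch_size * 1000 / ms_per_step, i.e. milliseconds per batched step converted to sequences per second. A small worked example with made-up numbers, not taken from an actual benchmark run:

    batch_size = 8
    ms_per_step = 40.0  # hypothetical timing for one batched forward pass
    throughput = batch_size * 1000 / ms_per_step  # = 200.0 sequences/sec
    print(f"Throughput: {throughput:.2f} sequences/sec")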
@@ -8,6 +8,7 @@ import mlx.core as mx
import torch
from huggingface_hub import snapshot_download

+
def download(hf_repo: str) -> Path:
"""Download model from Hugging Face."""
return Path(

@@ -17,13 +18,14 @@ def download(hf_repo: str) -> Path:
)
)


def remap_key(key: str) -> str:
"""Remap HuggingFace ESM key names to MLX format."""

# Skip position embeddings and position_ids
if "position_embeddings" in key or "position_ids" in key:
return None


# Map lm_head components properly
if key == "lm_head.decoder.weight":
return "lm_head.weight"

@@ -37,14 +39,14 @@ def remap_key(key: str) -> str:
return "lm_head.layer_norm.weight"
if key == "lm_head.layer_norm.bias":
return "lm_head.layer_norm.bias"


# Core remapping patterns
key = key.replace("esm.embeddings.word_embeddings", "embed_tokens")
key = key.replace("esm.encoder.emb_layer_norm_after", "emb_layer_norm_after")
key = key.replace("esm.encoder.layer.", "layer_")
key = key.replace("esm.contact_head", "contact_head")
key = key.replace("lm_head", "lm_head")


# Attention patterns
key = key.replace(".attention.self.", ".self_attn.")
key = key.replace(".attention.output.dense", ".self_attn.out_proj")

@@ -53,12 +55,12 @@ def remap_key(key: str) -> str:
key = key.replace(".key", ".k_proj")
key = key.replace(".value", ".v_proj")
key = key.replace(".rotary_embeddings", ".rot_emb")


# FFN patterns
key = key.replace(".intermediate.dense", ".fc1")
key = key.replace(".output.dense", ".fc2")
key = key.replace(".LayerNorm", ".final_layer_norm")


return key

@@ -70,24 +72,24 @@ def load_weights(model_path: Path) -> Dict:
if safetensors_path.exists():
print("Loading from safetensors...")
return mx.load(str(safetensors_path))


# Check for single bin file
single_bin_path = model_path / "pytorch_model.bin"
if single_bin_path.exists():
print("Loading from pytorch_model.bin...")
state_dict = torch.load(str(single_bin_path), map_location="cpu")
return {k: v.numpy() for k, v in state_dict.items()}


# Check for sharded bin files
index_file = model_path / "pytorch_model.bin.index.json"
if index_file.exists():
print("Loading from sharded bin files...")
with open(index_file) as f:
index = json.load(f)


# Get unique shard files
shard_files = set(index["weight_map"].values())


# Load all shards
state_dict = {}
for shard_file in sorted(shard_files):

@@ -95,9 +97,9 @@ def load_weights(model_path: Path) -> Dict:
shard_path = model_path / shard_file
shard_dict = torch.load(str(shard_path), map_location="cpu")
state_dict.update(shard_dict)


return {k: v.numpy() for k, v in state_dict.items()}


raise ValueError(f"No model weights found in {model_path}")

@@ -106,38 +108,34 @@ def convert(model_path: Path) -> Dict[str, mx.array]:

# Load weights from any format
weights = load_weights(model_path)


# Convert keys and create MLX arrays
mlx_weights = {}
for key, value in weights.items():
mlx_key = remap_key(key)
if mlx_key is not None:
-mlx_weights[mlx_key] = mx.array(value) if not isinstance(value, mx.array) else value

+mlx_weights[mlx_key] = (
+mx.array(value) if not isinstance(value, mx.array) else value
+)

# If lm_head.weight is missing but embed_tokens.weight exists, set up weight sharing
# (This is for smaller models that don't have a separate lm_head.decoder.weight)
if "lm_head.weight" not in mlx_weights and "embed_tokens.weight" in mlx_weights:
mlx_weights["lm_head.weight"] = mlx_weights["embed_tokens.weight"]


return mlx_weights


def main():
parser = argparse.ArgumentParser(description="Convert ESM weights to MLX format")
parser.add_argument(
-"--hf-path",
-default="facebook/esm2_t6_8M_UR50D",
-help="Hugging Face model path"
-)
-parser.add_argument(
-"--mlx-path",
-default=None,
-help="Output path for MLX model"
+"--hf-path", default="facebook/esm2_t6_8M_UR50D", help="Hugging Face model path"
)
+parser.add_argument("--mlx-path", default=None, help="Output path for MLX model")
parser.add_argument(
"--checkpoints-dir",
default="checkpoints",
-help="Directory to save checkpoints (default: checkpoints)"
+help="Directory to save checkpoints (default: checkpoints)",
)

args = parser.parse_args()

@@ -145,7 +143,7 @@ def main():
# Download model
print(f"Downloading {args.hf_path}...")
model_path = download(args.hf_path)


# Set output path
if args.mlx_path is None:
model_name = args.hf_path.split("/")[-1]

@@ -154,25 +152,26 @@ def main():
args.mlx_path = checkpoints_dir / f"mlx-{model_name}"
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)


# Convert weights
print("Converting weights...")
mlx_weights = convert(model_path)


# Save weights
print(f"Saving MLX weights to {mlx_path}...")
mx.save_safetensors(str(mlx_path / "model.safetensors"), mlx_weights)


# Copy config and other files
print("Copying config...")
shutil.copy(model_path / "config.json", mlx_path / "config.json")


for file_name in ["special_tokens_map.json", "tokenizer.json", "vocab.txt"]:
src_file = model_path / file_name
if src_file.exists():
shutil.copy(src_file, mlx_path / file_name)


print(f"Conversion complete! MLX model saved to {mlx_path}")


if __name__ == "__main__":
-main()
+main()
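remap_key, shown above, is a pure string-rewriting step applied to every checkpoint key before the weights are saved. A few hand-derived examples of how the visible replace rules behave (illustrative inputs only, not keys read from a real checkpoint):

    remap_key("esm.embeddings.word_embeddings.weight")           # -> "embed_tokens.weight"
    remap_key("esm.encoder.layer.3.attention.self.key.weight")   # -> "layer_3.self_attn.k_proj.weight"
    remap_key("lm_head.decoder.weight")                          # -> "lm_head.weight"
    remap_key("esm.embeddings.position_ids")                     # -> None (skipped by convert())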
@@ -2,18 +2,18 @@
ESM-2 protein language model implementation in MLX
"""

-from .model import ESM2
-from .tokenizer import ProteinTokenizer
-from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
from .attention import MultiheadAttention
+from .model import ESM2
+from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
+from .rotary_embedding import RotaryEmbedding
+from .tokenizer import ProteinTokenizer

__all__ = [
-'ESM2',
-'ProteinTokenizer',
-'ContactPredictionHead',
-'RobertaLMHead',
-'TransformerLayer',
-'MultiheadAttention',
-'RotaryEmbedding'
+"ESM2",
+"ProteinTokenizer",
+"ContactPredictionHead",
+"RobertaLMHead",
+"TransformerLayer",
+"MultiheadAttention",
+"RotaryEmbedding",
]
@@ -5,14 +5,15 @@ import mlx.nn as nn

from .rotary_embedding import RotaryEmbedding

+
class MultiheadAttention(nn.Module):
"""
Multi-head attention layer with rotary position embeddings, as used in ESM-2.

-This module implements both self-attention (when `key` and `value` are not
-provided) and cross-attention. It projects input sequences into queries,
-keys, and values, applies rotary position embeddings to encode relative
-position information, computes scaled dot-product attention over multiple
+This module implements both self-attention (when `key` and `value` are not
+provided) and cross-attention. It projects input sequences into queries,
+keys, and values, applies rotary position embeddings to encode relative
+position information, computes scaled dot-product attention over multiple
heads in parallel, and returns a combined output projection.

Args:

@@ -56,7 +57,7 @@ class MultiheadAttention(nn.Module):
) -> Tuple[mx.array, Optional[mx.array]]:
"""
Multi-head attention forward pass.


Args:
query: Tensor of shape (tgt_len, batch, embed_dim).
key: Optional tensor of shape (src_len, batch, embed_dim). Defaults to `query`.

@@ -67,32 +68,32 @@ class MultiheadAttention(nn.Module):

Returns:
attn_output: Tensor of shape (tgt_len, batch, embed_dim).
-attn_weights_out: Attention weights of shape
+attn_weights_out: Attention weights of shape
(num_heads, batch, tgt_len, src_len) if per-head,
or (batch, tgt_len, src_len) if averaged.
"""


tgt_len, bsz, embed_dim = query.shape
assert embed_dim == self.embed_dim


# For self-attention, use query as key and value if not provided
if key is None:
key = query
if value is None:
value = query


# Project queries, keys, values
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)


q = q * self.scaling

# Reshape for multi-head attention
q = q.reshape(tgt_len, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
k = k.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
v = v.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)


src_len = k.shape[1]

# Apply rotary embeddings if present

@@ -114,7 +115,9 @@ class MultiheadAttention(nn.Module):
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len)
# Convert key_padding_mask to boolean and expand dimensions
# key_padding_mask: [bsz, src_len] -> [bsz, 1, 1, src_len]
-mask = mx.expand_dims(mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2)
+mask = mx.expand_dims(
+mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2
+)
# Apply mask: set attention to -inf where mask is True (padded positions)
attn_weights = mx.where(mask, -mx.inf, attn_weights)
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)

@@ -126,25 +129,25 @@ class MultiheadAttention(nn.Module):
# Compute attention output
attn = attn_probs @ v
assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]


# Reshape output
attn = attn.swapaxes(0, 1).reshape(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)


# Return attention weights if requested
attn_weights_out: Optional[mx.array] = None
if need_head_weights:
# Return attention weights for each head separately
-attn_weights_out = attn_weights_float.reshape(
-bsz, self.num_heads, tgt_len, src_len
-).astype(attn.dtype).swapaxes(0, 1)
+attn_weights_out = (
+attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len)
+.astype(attn.dtype)
+.swapaxes(0, 1)
+)
else:
# Return averaged attention weights
attn_weights_out = mx.mean(
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len),
-axis=1
+axis=1,
).astype(attn.dtype)

return attn, attn_weights_out
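The key_padding_mask handling above broadcasts a (batch, src_len) padding mask against per-head attention scores. A minimal standalone sketch of that masking step, with shapes and values invented for illustration:

    import mlx.core as mx

    bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
    attn_weights = mx.zeros((bsz, num_heads, tgt_len, src_len))
    key_padding_mask = mx.array([[0, 0, 0, 1, 1],  # 1 marks padded key positions
                                 [0, 0, 0, 0, 1]])
    # (bsz, src_len) -> (bsz, 1, 1, src_len), then broadcast over heads and queries
    mask = mx.expand_dims(mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2)
    attn_weights = mx.where(mask, -mx.inf, attn_weights)
    attn_probs = mx.softmax(attn_weights, axis=-1)  # padded keys receive zero weight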
@@ -1,12 +1,12 @@
-from typing import List, Dict, Optional, Tuple
-from pathlib import Path
import json
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

-from .tokenizer import ProteinTokenizer
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
+from .tokenizer import ProteinTokenizer


class ESM2(nn.Module):

@@ -40,7 +40,7 @@ class ESM2(nn.Module):
self.tokenizer = tokenizer
self.vocab_size = len(tokenizer)

-# Special token IDs / config
+# Special token IDs / config
self.padding_idx = tokenizer.pad_id
self.mask_idx = tokenizer.mask_id
self.cls_idx = tokenizer.cls_id

@@ -128,8 +128,12 @@ class ESM2(nn.Module):

mask_ratio_train = 0.15 * 0.8
src_lengths = mx.sum(~padding_mask, axis=-1, keepdims=True)
-mask_ratio_observed = mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
-scale_factor = (1 - mask_ratio_train) / mx.maximum(1 - mask_ratio_observed, 1e-8)
+mask_ratio_observed = (
+mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
+)
+scale_factor = (1 - mask_ratio_train) / mx.maximum(
+1 - mask_ratio_observed, 1e-8
+)
x = x * scale_factor[:, None, :]

# Zero out padding positions

@@ -194,7 +198,9 @@ class ESM2(nn.Module):
# Mask out padded positions if present
if padding_mask is not None:
attention_mask = 1 - padding_mask.astype(attentions.dtype)
-attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(attention_mask, 2)
+attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(
+attention_mask, 2
+)
attentions = attentions * attention_mask[:, None, None, :, :]

result["attentions"] = attentions

@@ -304,7 +310,7 @@ class ESM2(nn.Module):
# Check for vocab and special tokens files
vocab_path = model_dir / "vocab.txt"
special_tokens_path = model_dir / "special_tokens_map.json"


if vocab_path.exists() and special_tokens_path.exists():
tokenizer = ProteinTokenizer(
vocab_file=str(vocab_path),

@@ -333,4 +339,4 @@ class ESM2(nn.Module):
cur[parts[-1]] = value

model.update(nested_weights)
-return tokenizer, model
+return tokenizer, model
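The scale factor reformatted in the @@ -128 hunk above follows ESM-2's token-dropout correction: embeddings are rescaled by how much of the sequence is actually masked relative to the 0.15 * 0.8 = 0.12 mask ratio seen during training. A rough worked example with invented numbers:

    mask_ratio_train = 0.15 * 0.8        # 0.12
    mask_ratio_observed = 1 / 100        # e.g. one <mask> token in a 100-token sequence
    scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
    # = 0.88 / 0.99 ≈ 0.8889, multiplied into that sequence's token embeddings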
@@ -58,7 +58,9 @@ class RotaryEmbedding(nn.Module):
self._cos_cached = None
self._sin_cached = None

-def _update_cos_sin_tables(self, x: mx.array, seq_dimension: int = 1) -> Tuple[mx.array, mx.array]:
+def _update_cos_sin_tables(
+self, x: mx.array, seq_dimension: int = 1
+) -> Tuple[mx.array, mx.array]:
"""
Compute and cache cos/sin tables for the given sequence length.

@@ -109,4 +111,4 @@ class RotaryEmbedding(nn.Module):
return (
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
-)
+)
@@ -1,14 +1,38 @@
-from typing import List, Sequence, Union, Optional
import json
from pathlib import Path
+from typing import List, Optional, Sequence, Union

import mlx.core as mx

# Canonical amino-acid tokens (IUPAC standard + uncommon variants)
PROTEIN_TOKENS = [
-"L", "A", "G", "V", "S", "E", "R", "T", "I", "D",
-"P", "K", "Q", "N", "F", "Y", "M", "H", "W", "C",
-"X", "B", "U", "Z", "O", ".", "-"
+"L",
+"A",
+"G",
+"V",
+"S",
+"E",
+"R",
+"T",
+"I",
+"D",
+"P",
+"K",
+"Q",
+"N",
+"F",
+"Y",
+"M",
+"H",
+"W",
+"C",
+"X",
+"B",
+"U",
+"Z",
+"O",
+".",
+"-",
]

ArrayLike = Union[List[int], mx.array]

@@ -39,7 +63,7 @@ class ProteinTokenizer:
If both files are provided, they override the default vocabulary and
special token mappings. Otherwise, defaults are loaded.
"""


# Load vocabulary from files if given, otherwise use built-in defaults
if vocab_file and special_tokens_map_file:
self._load_from_files(vocab_file, special_tokens_map_file)

@@ -51,15 +75,15 @@ class ProteinTokenizer:
self.id_to_token = {i: tok for i, tok in enumerate(self.vocab)}

# Cache special token IDs
-self.cls_id = self.token_to_id["<cls>"]
-self.pad_id = self.token_to_id["<pad>"]
-self.eos_id = self.token_to_id["<eos>"]
-self.unk_id = self.token_to_id["<unk>"]
+self.cls_id = self.token_to_id["<cls>"]
+self.pad_id = self.token_to_id["<pad>"]
+self.eos_id = self.token_to_id["<eos>"]
+self.unk_id = self.token_to_id["<unk>"]
self.mask_id = self.token_to_id["<mask>"]

# Behavior flags for ESM-2-style BOS/EOS
self.prepend_bos = True
-self.append_eos = True
+self.append_eos = True

def _load_from_files(self, vocab_file: str, special_tokens_map_file: str) -> None:
"""Load vocabulary and special tokens from the provided files."""

@@ -201,7 +225,11 @@ class ProteinTokenizer:
for i in ids_list:
tok = self.id_to_token.get(i, "<unk>")
if skip_special_tokens and tok in {
-"<cls>", "<pad>", "<eos>", "<unk>", "<mask>"
+"<cls>",
+"<pad>",
+"<eos>",
+"<unk>",
+"<mask>",
}:
continue
toks.append(tok)
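Based on the tokenizer methods exercised elsewhere in this diff (encode, batch_encode, decode, and the cached special-token ids), a round trip with the default vocabulary would look roughly like this sketch; the no-argument constructor and the default special-token behavior are inferred, not shown explicitly above:

    from esm import ProteinTokenizer

    tok = ProteinTokenizer()        # built-in ESM-2 vocabulary and special tokens
    ids = tok.encode("MKTAYIAKQR")  # main.py passes add_special_tokens=False to disable <cls>/<eos>
    print(tok.decode(ids, skip_special_tokens=True))  # expected: "MKTAYIAKQR"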
esm/main.py
@@ -1,58 +1,72 @@
import argparse

import mlx.core as mx

from esm import ESM2


def main():
parser = argparse.ArgumentParser(description="ESM-2 MLX Inference")
parser.add_argument(
-"--model-path",
+"--model-path",
default="checkpoints/mlx-esm2_t33_650M_UR50D",
-help="Path to MLX model checkpoint"
+help="Path to MLX model checkpoint",
)
parser.add_argument(
-"--sequence",
+"--sequence",
default="MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
-help="Protein sequence to test (default: human insulin)"
+help="Protein sequence to test (default: human insulin)",
)
parser.add_argument(
-"--mask-position",
-type=int,
+"--mask-position",
+type=int,
default=None,
-help="Position to mask (default: middle of sequence)"
+help="Position to mask (default: middle of sequence)",
)
args = parser.parse_args()


# Load pretrained ESM-2 model and tokenizer
tokenizer, model = ESM2.from_pretrained(args.model_path)


# Determine sequence and position to mask
sequence = args.sequence.upper()
-mask_pos = args.mask_position if args.mask_position is not None else len(sequence) // 2
+mask_pos = (
+args.mask_position if args.mask_position is not None else len(sequence) // 2
+)
if mask_pos >= len(sequence):
mask_pos = len(sequence) - 1
original_aa = sequence[mask_pos]  # The original amino acid at masked position


# Display input info
print(f"Original sequence: {sequence}")
print(f"Masked sequence: {sequence[:mask_pos]}<mask>{sequence[mask_pos+1:]}")
print(f"Predicting position {mask_pos}: {original_aa}\n")


# Tokenize sequence before and after the mask
before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False)
-after = tokenizer.encode(sequence[mask_pos+1:], add_special_tokens=False)

+after = tokenizer.encode(sequence[mask_pos + 1 :], add_special_tokens=False)

# Build token sequence with <cls>, <mask>, and <eos>
-tokens = mx.array([[tokenizer.cls_id] + before.tolist() + [tokenizer.mask_id] + after.tolist() + [tokenizer.eos_id]])
+tokens = mx.array(
+    [
+        [tokenizer.cls_id]
+        + before.tolist()
+        + [tokenizer.mask_id]
+        + after.tolist()
+        + [tokenizer.eos_id]
+    ]
+)
mask_token_pos = 1 + len(before)  # Position of <mask> token


# Run model to get logits for each token position
logits = model(tokens)["logits"]
-probs = mx.softmax(logits[0, mask_token_pos, :])  # Softmax over vocabulary at mask position

+probs = mx.softmax(
+logits[0, mask_token_pos, :]
+)  # Softmax over vocabulary at mask position

# Get top-5 most likely tokens
top_indices = mx.argsort(probs)[-5:][::-1]


# Print predictions
print("Top predictions:")
for i, idx in enumerate(top_indices):

@@ -62,5 +76,6 @@ def main():
marker = "✓" if token == original_aa else " "
print(f"{marker} {i+1}. {token}: {prob:.3f} ({prob*100:.1f}%)")


if __name__ == "__main__":
-main()
+main()
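esm/main.py wires these pieces together for single-residue masked prediction; the same flow can also be driven directly from Python, for example (the checkpoint path below is main.py's default and must already exist from the conversion step, and the short sequence is invented):

    import mlx.core as mx
    from esm import ESM2

    tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")
    seq = "MKTAYIAKQR"
    ids = tokenizer.encode(seq, add_special_tokens=False)
    tokens = mx.array([[tokenizer.cls_id] + ids.tolist() + [tokenizer.eos_id]])
    logits = model(tokens)["logits"]     # shape: (1, len(seq) + 2, vocab_size)
    probs = mx.softmax(logits[0, 1, :])  # distribution over the vocabulary for the first residue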
esm/test.py
@@ -1,6 +1,7 @@
import unittest

import numpy as np
-from transformers import AutoTokenizer, EsmForMaskedLM, EsmConfig
+from transformers import AutoTokenizer, EsmConfig, EsmForMaskedLM

from esm import ESM2

@@ -8,65 +9,77 @@ from esm import ESM2
MLX_PATH = "checkpoints/mlx-esm2_t12_35M_UR50D"
HF_PATH = "facebook/esm2_t12_35M_UR50D"


def load_mlx_model():
"""Load MLX ESM-2 model and tokenizer."""
tokenizer, model = ESM2.from_pretrained(MLX_PATH)
return tokenizer, model


def load_hf_model():
"""Load Hugging Face ESM-2 model and tokenizer with hidden states + attentions."""
tokenizer = AutoTokenizer.from_pretrained(HF_PATH)
config = EsmConfig.from_pretrained(
-HF_PATH,
-output_hidden_states=True,
-output_attentions=True
+HF_PATH, output_hidden_states=True, output_attentions=True
)
model = EsmForMaskedLM.from_pretrained(HF_PATH, config=config)
return tokenizer, model


class TestESM2(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Load both MLX and HF models/tokenizers once for all tests
cls.mlx_tokenizer, cls.mlx_model = load_mlx_model()
cls.hf_tokenizer, cls.hf_model = load_hf_model()


def test_tokenizer(self):
"""Verify MLX tokenizer matches Hugging Face tokenizer behavior."""
self.assertEqual(len(self.mlx_tokenizer), len(self.hf_tokenizer))

sequences = [
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
-"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
+"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
]

# Compare batched tokenization (padded sequences)
mlx_batch = self.mlx_tokenizer.batch_encode(sequences)
-hf_batch = self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"].cpu().numpy()
+hf_batch = (
+self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"]
+.cpu()
+.numpy()
+)
self.assertEqual(tuple(mlx_batch.shape), tuple(hf_batch.shape))
self.assertTrue(
np.array_equal(np.array(mlx_batch.tolist(), dtype=hf_batch.dtype), hf_batch)
)


# Compare single-sequence encode/decode
for sequence in sequences:
mlx_tokens = self.mlx_tokenizer.encode(sequence)
-hf_tokens = self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"].cpu().numpy().tolist()[0]
+hf_tokens = (
+self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"]
+.cpu()
+.numpy()
+.tolist()[0]
+)
self.assertTrue(np.array_equal(mlx_tokens, hf_tokens))
self.assertEqual(
self.mlx_tokenizer.decode(mlx_tokens),
-self.hf_tokenizer.decode(hf_tokens).replace(" ", "")
+self.hf_tokenizer.decode(hf_tokens).replace(" ", ""),
)
self.assertEqual(
self.mlx_tokenizer.decode(mlx_tokens, skip_special_tokens=True),
-self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(" ", "")
+self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(
+" ", ""
+),
)

def test_model(self):
"""Verify MLX and HF model outputs match (logits, hidden states, attentions)."""
sequences = [
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
-"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
+"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
]
for sequence in sequences:
# Tokenize

@@ -75,9 +88,9 @@ class TestESM2(unittest.TestCase):

# Forward pass
mlx_outputs = self.mlx_model(
-mlx_tokens,
-repr_layers=[self.mlx_model.num_layers],
-need_head_weights=True
+mlx_tokens,
+repr_layers=[self.mlx_model.num_layers],
+need_head_weights=True,
)
hf_outputs = self.hf_model(input_ids=hf_tokens)

@@ -90,12 +103,19 @@ class TestESM2(unittest.TestCase):
final_layer = self.mlx_model.num_layers
mlx_hidden_states = np.array(mlx_outputs["representations"][final_layer])
hf_hidden_states = hf_outputs["hidden_states"][-1].detach().cpu().numpy()
-self.assertTrue(np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4))
+self.assertTrue(
+np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4)
+)

# Compare attentions for final layer
-mlx_attentions = np.array(mlx_outputs["attentions"][:, final_layer-1, :, :, :])
+mlx_attentions = np.array(
+mlx_outputs["attentions"][:, final_layer - 1, :, :, :]
+)
hf_attentions = hf_outputs["attentions"][-1].detach().cpu().numpy()
-self.assertTrue(np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4))
+self.assertTrue(
+np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4)
+)


if __name__ == "__main__":
unittest.main()