From 98d800866e62cdfb99cb71382f9e85b0a2684d87 Mon Sep 17 00:00:00 2001 From: Vincent Amato Date: Fri, 15 Aug 2025 23:54:53 -0400 Subject: [PATCH] Formatted code --- esm/benchmarks/benchmark_mx.py | 2 +- esm/benchmarks/benchmark_pt.py | 10 +++--- esm/convert.py | 61 +++++++++++++++++----------------- esm/esm/__init__.py | 20 +++++------ esm/esm/attention.py | 43 +++++++++++++----------- esm/esm/model.py | 24 ++++++++----- esm/esm/rotary_embedding.py | 6 ++-- esm/esm/tokenizer.py | 50 ++++++++++++++++++++++------ esm/main.py | 55 +++++++++++++++++++----------- esm/test.py | 56 +++++++++++++++++++++---------- 10 files changed, 200 insertions(+), 127 deletions(-) diff --git a/esm/benchmarks/benchmark_mx.py b/esm/benchmarks/benchmark_mx.py index 7f9740b7..c970d34e 100644 --- a/esm/benchmarks/benchmark_mx.py +++ b/esm/benchmarks/benchmark_mx.py @@ -44,4 +44,4 @@ throughput = batch_size * 1000 / ms_per_step # Display results print(f"Time (ms) per step: {ms_per_step:.3f}") -print(f"Throughput: {throughput:.2f} sequences/sec") \ No newline at end of file +print(f"Throughput: {throughput:.2f} sequences/sec") diff --git a/esm/benchmarks/benchmark_pt.py b/esm/benchmarks/benchmark_pt.py index 96a8e2f0..44f28e5b 100644 --- a/esm/benchmarks/benchmark_pt.py +++ b/esm/benchmarks/benchmark_pt.py @@ -22,11 +22,11 @@ steps = 50 # Tokenize input sequence and replicate for the batch # Replicate the same sequence 'batch_size' times to create a batch inputs = tokenizer( - [protein_sequence] * batch_size, - return_tensors="pt", - padding=True, + [protein_sequence] * batch_size, + return_tensors="pt", + padding=True, truncation=True, - max_length=1024 + max_length=1024, ) input_ids = inputs["input_ids"].to("mps") attention_mask = inputs["attention_mask"].to("mps") @@ -49,4 +49,4 @@ throughput = batch_size * 1000 / ms_per_step # Report results print(f"Time (ms) per step: {ms_per_step:.3f}") -print(f"Throughput: {throughput:.2f} sequences/sec") \ No newline at end of file +print(f"Throughput: {throughput:.2f} sequences/sec") diff --git a/esm/convert.py b/esm/convert.py index 9e1c2291..bfbc4be7 100644 --- a/esm/convert.py +++ b/esm/convert.py @@ -8,6 +8,7 @@ import mlx.core as mx import torch from huggingface_hub import snapshot_download + def download(hf_repo: str) -> Path: """Download model from Hugging Face.""" return Path( @@ -17,13 +18,14 @@ def download(hf_repo: str) -> Path: ) ) + def remap_key(key: str) -> str: """Remap HuggingFace ESM key names to MLX format.""" # Skip position embeddings and position_ids if "position_embeddings" in key or "position_ids" in key: return None - + # Map lm_head components properly if key == "lm_head.decoder.weight": return "lm_head.weight" @@ -37,14 +39,14 @@ def remap_key(key: str) -> str: return "lm_head.layer_norm.weight" if key == "lm_head.layer_norm.bias": return "lm_head.layer_norm.bias" - + # Core remapping patterns key = key.replace("esm.embeddings.word_embeddings", "embed_tokens") key = key.replace("esm.encoder.emb_layer_norm_after", "emb_layer_norm_after") key = key.replace("esm.encoder.layer.", "layer_") key = key.replace("esm.contact_head", "contact_head") key = key.replace("lm_head", "lm_head") - + # Attention patterns key = key.replace(".attention.self.", ".self_attn.") key = key.replace(".attention.output.dense", ".self_attn.out_proj") @@ -53,12 +55,12 @@ def remap_key(key: str) -> str: key = key.replace(".key", ".k_proj") key = key.replace(".value", ".v_proj") key = key.replace(".rotary_embeddings", ".rot_emb") - + # FFN patterns key = 
key.replace(".intermediate.dense", ".fc1") key = key.replace(".output.dense", ".fc2") key = key.replace(".LayerNorm", ".final_layer_norm") - + return key @@ -70,24 +72,24 @@ def load_weights(model_path: Path) -> Dict: if safetensors_path.exists(): print("Loading from safetensors...") return mx.load(str(safetensors_path)) - + # Check for single bin file single_bin_path = model_path / "pytorch_model.bin" if single_bin_path.exists(): print("Loading from pytorch_model.bin...") state_dict = torch.load(str(single_bin_path), map_location="cpu") return {k: v.numpy() for k, v in state_dict.items()} - + # Check for sharded bin files index_file = model_path / "pytorch_model.bin.index.json" if index_file.exists(): print("Loading from sharded bin files...") with open(index_file) as f: index = json.load(f) - + # Get unique shard files shard_files = set(index["weight_map"].values()) - + # Load all shards state_dict = {} for shard_file in sorted(shard_files): @@ -95,9 +97,9 @@ def load_weights(model_path: Path) -> Dict: shard_path = model_path / shard_file shard_dict = torch.load(str(shard_path), map_location="cpu") state_dict.update(shard_dict) - + return {k: v.numpy() for k, v in state_dict.items()} - + raise ValueError(f"No model weights found in {model_path}") @@ -106,38 +108,34 @@ def convert(model_path: Path) -> Dict[str, mx.array]: # Load weights from any format weights = load_weights(model_path) - + # Convert keys and create MLX arrays mlx_weights = {} for key, value in weights.items(): mlx_key = remap_key(key) if mlx_key is not None: - mlx_weights[mlx_key] = mx.array(value) if not isinstance(value, mx.array) else value - + mlx_weights[mlx_key] = ( + mx.array(value) if not isinstance(value, mx.array) else value + ) + # If lm_head.weight is missing but embed_tokens.weight exists, set up weight sharing # (This is for smaller models that don't have a separate lm_head.decoder.weight) if "lm_head.weight" not in mlx_weights and "embed_tokens.weight" in mlx_weights: mlx_weights["lm_head.weight"] = mlx_weights["embed_tokens.weight"] - + return mlx_weights def main(): parser = argparse.ArgumentParser(description="Convert ESM weights to MLX format") parser.add_argument( - "--hf-path", - default="facebook/esm2_t6_8M_UR50D", - help="Hugging Face model path" - ) - parser.add_argument( - "--mlx-path", - default=None, - help="Output path for MLX model" + "--hf-path", default="facebook/esm2_t6_8M_UR50D", help="Hugging Face model path" ) + parser.add_argument("--mlx-path", default=None, help="Output path for MLX model") parser.add_argument( "--checkpoints-dir", default="checkpoints", - help="Directory to save checkpoints (default: checkpoints)" + help="Directory to save checkpoints (default: checkpoints)", ) args = parser.parse_args() @@ -145,7 +143,7 @@ def main(): # Download model print(f"Downloading {args.hf_path}...") model_path = download(args.hf_path) - + # Set output path if args.mlx_path is None: model_name = args.hf_path.split("/")[-1] @@ -154,25 +152,26 @@ def main(): args.mlx_path = checkpoints_dir / f"mlx-{model_name}" mlx_path = Path(args.mlx_path) mlx_path.mkdir(parents=True, exist_ok=True) - + # Convert weights print("Converting weights...") mlx_weights = convert(model_path) - + # Save weights print(f"Saving MLX weights to {mlx_path}...") mx.save_safetensors(str(mlx_path / "model.safetensors"), mlx_weights) - + # Copy config and other files print("Copying config...") shutil.copy(model_path / "config.json", mlx_path / "config.json") - + for file_name in ["special_tokens_map.json", "tokenizer.json", 
"vocab.txt"]: src_file = model_path / file_name if src_file.exists(): shutil.copy(src_file, mlx_path / file_name) - + print(f"Conversion complete! MLX model saved to {mlx_path}") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/esm/esm/__init__.py b/esm/esm/__init__.py index 77ceeeee..1c2139d4 100644 --- a/esm/esm/__init__.py +++ b/esm/esm/__init__.py @@ -2,18 +2,18 @@ ESM-2 protein language model implementation in MLX """ -from .model import ESM2 -from .tokenizer import ProteinTokenizer -from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer from .attention import MultiheadAttention +from .model import ESM2 +from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer from .rotary_embedding import RotaryEmbedding +from .tokenizer import ProteinTokenizer __all__ = [ - 'ESM2', - 'ProteinTokenizer', - 'ContactPredictionHead', - 'RobertaLMHead', - 'TransformerLayer', - 'MultiheadAttention', - 'RotaryEmbedding' + "ESM2", + "ProteinTokenizer", + "ContactPredictionHead", + "RobertaLMHead", + "TransformerLayer", + "MultiheadAttention", + "RotaryEmbedding", ] diff --git a/esm/esm/attention.py b/esm/esm/attention.py index de5b6f9d..be535b25 100644 --- a/esm/esm/attention.py +++ b/esm/esm/attention.py @@ -5,14 +5,15 @@ import mlx.nn as nn from .rotary_embedding import RotaryEmbedding + class MultiheadAttention(nn.Module): """ Multi-head attention layer with rotary position embeddings, as used in ESM-2. - This module implements both self-attention (when `key` and `value` are not - provided) and cross-attention. It projects input sequences into queries, - keys, and values, applies rotary position embeddings to encode relative - position information, computes scaled dot-product attention over multiple + This module implements both self-attention (when `key` and `value` are not + provided) and cross-attention. It projects input sequences into queries, + keys, and values, applies rotary position embeddings to encode relative + position information, computes scaled dot-product attention over multiple heads in parallel, and returns a combined output projection. Args: @@ -56,7 +57,7 @@ class MultiheadAttention(nn.Module): ) -> Tuple[mx.array, Optional[mx.array]]: """ Multi-head attention forward pass. - + Args: query: Tensor of shape (tgt_len, batch, embed_dim). key: Optional tensor of shape (src_len, batch, embed_dim). Defaults to `query`. @@ -67,32 +68,32 @@ class MultiheadAttention(nn.Module): Returns: attn_output: Tensor of shape (tgt_len, batch, embed_dim). - attn_weights_out: Attention weights of shape + attn_weights_out: Attention weights of shape (num_heads, batch, tgt_len, src_len) if per-head, or (batch, tgt_len, src_len) if averaged. 
""" - + tgt_len, bsz, embed_dim = query.shape assert embed_dim == self.embed_dim - + # For self-attention, use query as key and value if not provided if key is None: key = query if value is None: value = query - + # Project queries, keys, values q = self.q_proj(query) k = self.k_proj(key) v = self.v_proj(value) - + q = q * self.scaling # Reshape for multi-head attention q = q.reshape(tgt_len, bsz * self.num_heads, self.head_dim).swapaxes(0, 1) k = k.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1) v = v.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1) - + src_len = k.shape[1] # Apply rotary embeddings if present @@ -114,7 +115,9 @@ class MultiheadAttention(nn.Module): attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) # Convert key_padding_mask to boolean and expand dimensions # key_padding_mask: [bsz, src_len] -> [bsz, 1, 1, src_len] - mask = mx.expand_dims(mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2) + mask = mx.expand_dims( + mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2 + ) # Apply mask: set attention to -inf where mask is True (padded positions) attn_weights = mx.where(mask, -mx.inf, attn_weights) attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len) @@ -126,25 +129,25 @@ class MultiheadAttention(nn.Module): # Compute attention output attn = attn_probs @ v assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] - + # Reshape output attn = attn.swapaxes(0, 1).reshape(tgt_len, bsz, embed_dim) attn = self.out_proj(attn) - + # Return attention weights if requested attn_weights_out: Optional[mx.array] = None if need_head_weights: # Return attention weights for each head separately - attn_weights_out = attn_weights_float.reshape( - bsz, self.num_heads, tgt_len, src_len - ).astype(attn.dtype).swapaxes(0, 1) + attn_weights_out = ( + attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len) + .astype(attn.dtype) + .swapaxes(0, 1) + ) else: # Return averaged attention weights attn_weights_out = mx.mean( attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len), - axis=1 + axis=1, ).astype(attn.dtype) return attn, attn_weights_out - - diff --git a/esm/esm/model.py b/esm/esm/model.py index 45ce9902..69e28daa 100644 --- a/esm/esm/model.py +++ b/esm/esm/model.py @@ -1,12 +1,12 @@ -from typing import List, Dict, Optional, Tuple -from pathlib import Path import json +from pathlib import Path +from typing import Dict, List, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .tokenizer import ProteinTokenizer from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer +from .tokenizer import ProteinTokenizer class ESM2(nn.Module): @@ -40,7 +40,7 @@ class ESM2(nn.Module): self.tokenizer = tokenizer self.vocab_size = len(tokenizer) - # Special token IDs / config + # Special token IDs / config self.padding_idx = tokenizer.pad_id self.mask_idx = tokenizer.mask_id self.cls_idx = tokenizer.cls_id @@ -128,8 +128,12 @@ class ESM2(nn.Module): mask_ratio_train = 0.15 * 0.8 src_lengths = mx.sum(~padding_mask, axis=-1, keepdims=True) - mask_ratio_observed = mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths - scale_factor = (1 - mask_ratio_train) / mx.maximum(1 - mask_ratio_observed, 1e-8) + mask_ratio_observed = ( + mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths + ) + scale_factor = (1 - mask_ratio_train) / mx.maximum( + 1 - mask_ratio_observed, 1e-8 + ) x = x * scale_factor[:, None, :] # Zero out padding positions @@ -194,7 +198,9 
@@ class ESM2(nn.Module): # Mask out padded positions if present if padding_mask is not None: attention_mask = 1 - padding_mask.astype(attentions.dtype) - attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(attention_mask, 2) + attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims( + attention_mask, 2 + ) attentions = attentions * attention_mask[:, None, None, :, :] result["attentions"] = attentions @@ -304,7 +310,7 @@ class ESM2(nn.Module): # Check for vocab and special tokens files vocab_path = model_dir / "vocab.txt" special_tokens_path = model_dir / "special_tokens_map.json" - + if vocab_path.exists() and special_tokens_path.exists(): tokenizer = ProteinTokenizer( vocab_file=str(vocab_path), @@ -333,4 +339,4 @@ class ESM2(nn.Module): cur[parts[-1]] = value model.update(nested_weights) - return tokenizer, model \ No newline at end of file + return tokenizer, model diff --git a/esm/esm/rotary_embedding.py b/esm/esm/rotary_embedding.py index def31d2f..47f458eb 100644 --- a/esm/esm/rotary_embedding.py +++ b/esm/esm/rotary_embedding.py @@ -58,7 +58,9 @@ class RotaryEmbedding(nn.Module): self._cos_cached = None self._sin_cached = None - def _update_cos_sin_tables(self, x: mx.array, seq_dimension: int = 1) -> Tuple[mx.array, mx.array]: + def _update_cos_sin_tables( + self, x: mx.array, seq_dimension: int = 1 + ) -> Tuple[mx.array, mx.array]: """ Compute and cache cos/sin tables for the given sequence length. @@ -109,4 +111,4 @@ class RotaryEmbedding(nn.Module): return ( apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached), - ) \ No newline at end of file + ) diff --git a/esm/esm/tokenizer.py b/esm/esm/tokenizer.py index 723fe0e8..7c8caf58 100644 --- a/esm/esm/tokenizer.py +++ b/esm/esm/tokenizer.py @@ -1,14 +1,38 @@ -from typing import List, Sequence, Union, Optional import json from pathlib import Path +from typing import List, Optional, Sequence, Union import mlx.core as mx # Canonical amino-acid tokens (IUPAC standard + uncommon variants) PROTEIN_TOKENS = [ - "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", - "P", "K", "Q", "N", "F", "Y", "M", "H", "W", "C", - "X", "B", "U", "Z", "O", ".", "-" + "L", + "A", + "G", + "V", + "S", + "E", + "R", + "T", + "I", + "D", + "P", + "K", + "Q", + "N", + "F", + "Y", + "M", + "H", + "W", + "C", + "X", + "B", + "U", + "Z", + "O", + ".", + "-", ] ArrayLike = Union[List[int], mx.array] @@ -39,7 +63,7 @@ class ProteinTokenizer: If both files are provided, they override the default vocabulary and special token mappings. Otherwise, defaults are loaded. 
""" - + # Load vocabulary from files if given, otherwise use built-in defaults if vocab_file and special_tokens_map_file: self._load_from_files(vocab_file, special_tokens_map_file) @@ -51,15 +75,15 @@ class ProteinTokenizer: self.id_to_token = {i: tok for i, tok in enumerate(self.vocab)} # Cache special token IDs - self.cls_id = self.token_to_id[""] - self.pad_id = self.token_to_id[""] - self.eos_id = self.token_to_id[""] - self.unk_id = self.token_to_id[""] + self.cls_id = self.token_to_id[""] + self.pad_id = self.token_to_id[""] + self.eos_id = self.token_to_id[""] + self.unk_id = self.token_to_id[""] self.mask_id = self.token_to_id[""] # Behavior flags for ESM-2-style BOS/EOS self.prepend_bos = True - self.append_eos = True + self.append_eos = True def _load_from_files(self, vocab_file: str, special_tokens_map_file: str) -> None: """Load vocabulary and special tokens from the provided files.""" @@ -201,7 +225,11 @@ class ProteinTokenizer: for i in ids_list: tok = self.id_to_token.get(i, "") if skip_special_tokens and tok in { - "", "", "", "", "" + "", + "", + "", + "", + "", }: continue toks.append(tok) diff --git a/esm/main.py b/esm/main.py index a73cf4b2..9480a03f 100644 --- a/esm/main.py +++ b/esm/main.py @@ -1,58 +1,72 @@ import argparse + import mlx.core as mx from esm import ESM2 + def main(): parser = argparse.ArgumentParser(description="ESM-2 MLX Inference") parser.add_argument( - "--model-path", + "--model-path", default="checkpoints/mlx-esm2_t33_650M_UR50D", - help="Path to MLX model checkpoint" + help="Path to MLX model checkpoint", ) parser.add_argument( - "--sequence", + "--sequence", default="MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN", - help="Protein sequence to test (default: human insulin)" + help="Protein sequence to test (default: human insulin)", ) parser.add_argument( - "--mask-position", - type=int, + "--mask-position", + type=int, default=None, - help="Position to mask (default: middle of sequence)" + help="Position to mask (default: middle of sequence)", ) args = parser.parse_args() - + # Load pretrained ESM-2 model and tokenizer tokenizer, model = ESM2.from_pretrained(args.model_path) - + # Determine sequence and position to mask sequence = args.sequence.upper() - mask_pos = args.mask_position if args.mask_position is not None else len(sequence) // 2 + mask_pos = ( + args.mask_position if args.mask_position is not None else len(sequence) // 2 + ) if mask_pos >= len(sequence): mask_pos = len(sequence) - 1 original_aa = sequence[mask_pos] # The original amino acid at masked position - + # Display input info print(f"Original sequence: {sequence}") print(f"Masked sequence: {sequence[:mask_pos]}{sequence[mask_pos+1:]}") print(f"Predicting position {mask_pos}: {original_aa}\n") - + # Tokenize sequence before and after the mask before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False) - after = tokenizer.encode(sequence[mask_pos+1:], add_special_tokens=False) - + after = tokenizer.encode(sequence[mask_pos + 1 :], add_special_tokens=False) + # Build token sequence with , , and - tokens = mx.array([[tokenizer.cls_id] + before.tolist() + [tokenizer.mask_id] + after.tolist() + [tokenizer.eos_id]]) + tokens = mx.array( + [ + [tokenizer.cls_id] + + before.tolist() + + [tokenizer.mask_id] + + after.tolist() + + [tokenizer.eos_id] + ] + ) mask_token_pos = 1 + len(before) # Position of token - + # Run model to get logits for each token position logits = model(tokens)["logits"] - probs = 
mx.softmax(logits[0, mask_token_pos, :]) # Softmax over vocabulary at mask position - + probs = mx.softmax( + logits[0, mask_token_pos, :] + ) # Softmax over vocabulary at mask position + # Get top-5 most likely tokens top_indices = mx.argsort(probs)[-5:][::-1] - + # Print predictions print("Top predictions:") for i, idx in enumerate(top_indices): @@ -62,5 +76,6 @@ def main(): marker = "✓" if token == original_aa else " " print(f"{marker} {i+1}. {token}: {prob:.3f} ({prob*100:.1f}%)") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/esm/test.py b/esm/test.py index a784e453..e9fb3bd8 100644 --- a/esm/test.py +++ b/esm/test.py @@ -1,6 +1,7 @@ import unittest + import numpy as np -from transformers import AutoTokenizer, EsmForMaskedLM, EsmConfig +from transformers import AutoTokenizer, EsmConfig, EsmForMaskedLM from esm import ESM2 @@ -8,65 +9,77 @@ from esm import ESM2 MLX_PATH = "checkpoints/mlx-esm2_t12_35M_UR50D" HF_PATH = "facebook/esm2_t12_35M_UR50D" + def load_mlx_model(): """Load MLX ESM-2 model and tokenizer.""" tokenizer, model = ESM2.from_pretrained(MLX_PATH) return tokenizer, model + def load_hf_model(): """Load Hugging Face ESM-2 model and tokenizer with hidden states + attentions.""" tokenizer = AutoTokenizer.from_pretrained(HF_PATH) config = EsmConfig.from_pretrained( - HF_PATH, - output_hidden_states=True, - output_attentions=True + HF_PATH, output_hidden_states=True, output_attentions=True ) model = EsmForMaskedLM.from_pretrained(HF_PATH, config=config) return tokenizer, model + class TestESM2(unittest.TestCase): @classmethod def setUpClass(cls): # Load both MLX and HF models/tokenizers once for all tests cls.mlx_tokenizer, cls.mlx_model = load_mlx_model() cls.hf_tokenizer, cls.hf_model = load_hf_model() - + def test_tokenizer(self): """Verify MLX tokenizer matches Hugging Face tokenizer behavior.""" self.assertEqual(len(self.mlx_tokenizer), len(self.hf_tokenizer)) sequences = [ "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK", - "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN" + "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN", ] # Compare batched tokenization (padded sequences) mlx_batch = self.mlx_tokenizer.batch_encode(sequences) - hf_batch = self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"].cpu().numpy() + hf_batch = ( + self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"] + .cpu() + .numpy() + ) self.assertEqual(tuple(mlx_batch.shape), tuple(hf_batch.shape)) self.assertTrue( np.array_equal(np.array(mlx_batch.tolist(), dtype=hf_batch.dtype), hf_batch) ) - + # Compare single-sequence encode/decode for sequence in sequences: mlx_tokens = self.mlx_tokenizer.encode(sequence) - hf_tokens = self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"].cpu().numpy().tolist()[0] + hf_tokens = ( + self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"] + .cpu() + .numpy() + .tolist()[0] + ) self.assertTrue(np.array_equal(mlx_tokens, hf_tokens)) self.assertEqual( self.mlx_tokenizer.decode(mlx_tokens), - self.hf_tokenizer.decode(hf_tokens).replace(" ", "") + self.hf_tokenizer.decode(hf_tokens).replace(" ", ""), ) self.assertEqual( self.mlx_tokenizer.decode(mlx_tokens, 
skip_special_tokens=True), - self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(" ", "") + self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace( + " ", "" + ), ) def test_model(self): """Verify MLX and HF model outputs match (logits, hidden states, attentions).""" sequences = [ "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK", - "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN" + "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN", ] for sequence in sequences: # Tokenize @@ -75,9 +88,9 @@ class TestESM2(unittest.TestCase): # Forward pass mlx_outputs = self.mlx_model( - mlx_tokens, - repr_layers=[self.mlx_model.num_layers], - need_head_weights=True + mlx_tokens, + repr_layers=[self.mlx_model.num_layers], + need_head_weights=True, ) hf_outputs = self.hf_model(input_ids=hf_tokens) @@ -90,12 +103,19 @@ class TestESM2(unittest.TestCase): final_layer = self.mlx_model.num_layers mlx_hidden_states = np.array(mlx_outputs["representations"][final_layer]) hf_hidden_states = hf_outputs["hidden_states"][-1].detach().cpu().numpy() - self.assertTrue(np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4)) + self.assertTrue( + np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4) + ) # Compare attentions for final layer - mlx_attentions = np.array(mlx_outputs["attentions"][:, final_layer-1, :, :, :]) + mlx_attentions = np.array( + mlx_outputs["attentions"][:, final_layer - 1, :, :, :] + ) hf_attentions = hf_outputs["attentions"][-1].detach().cpu().numpy() - self.assertTrue(np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4)) + self.assertTrue( + np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4) + ) + if __name__ == "__main__": unittest.main()