mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-11-03 21:18:10 +08:00
BERT implementation
This commit is contained in:
36
bert/hf_model.py
Normal file
36
bert/hf_model.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
def run(bert_model: str):
|
||||
batch = [
|
||||
"This is an example of BERT working on MLX.",
|
||||
"A second string",
|
||||
"This is another string.",
|
||||
]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_model)
|
||||
torch_model = AutoModel.from_pretrained(bert_model)
|
||||
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
|
||||
torch_forward = torch_model(**torch_tokens)
|
||||
torch_output = torch_forward.last_hidden_state.detach().numpy()
|
||||
torch_pooled = torch_forward.pooler_output.detach().numpy()
|
||||
|
||||
print("\n HF BERT:")
|
||||
print(torch_output)
|
||||
print("\n\n HF Pooled:")
|
||||
print(torch_pooled[0, :20])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--bert-model",
|
||||
type=str,
|
||||
default="bert-base-uncased",
|
||||
help="The huggingface name of the BERT model to save.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run(args.bert_model)
|
||||
Reference in New Issue
Block a user