mlx-examples/bert/hf_model.py

import argparse

from transformers import AutoModel, AutoTokenizer


def run(bert_model: str):
    batch = [
        "This is an example of BERT working on MLX.",
        "A second string",
        "This is another string.",
    ]

    # Tokenize the batch (with padding) and run the PyTorch reference model.
    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    torch_model = AutoModel.from_pretrained(bert_model)
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
    torch_forward = torch_model(**torch_tokens)
    torch_output = torch_forward.last_hidden_state.detach().numpy()
    torch_pooled = torch_forward.pooler_output.detach().numpy()

    # Print the full last-hidden-state output and the first 20 pooled values.
    print("\n HF BERT:")
    print(torch_output)
    print("\n\n HF Pooled:")
    print(torch_pooled[0, :20])
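

# Note: with the bert-base-* checkpoints the hidden size is 768, so torch_output
# has shape (batch, seq_len, 768) and torch_pooled has shape (batch, 768); the
# bert-large-* checkpoints use a hidden size of 1024.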
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the BERT model using Hugging Face Transformers."
    )
    parser.add_argument(
        "--bert-model",
        choices=[
            "bert-base-uncased",
            "bert-base-cased",
            "bert-large-uncased",
            "bert-large-cased",
        ],
        default="bert-base-uncased",
        help="The Hugging Face name of the BERT model to run.",
    )
    args = parser.parse_args()
    run(args.bert_model)
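
# Example invocation (a sketch; assumes this file is saved as hf_model.py and the
# transformers and torch packages are installed):
#
#   python hf_model.py --bert-model bert-base-uncased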