2023-12-08 18:14:11 +08:00
|
|
|
import argparse
|
|
|
|
|
2023-12-21 02:22:25 +08:00
|
|
|
from transformers import AutoModel, AutoTokenizer
|
|
|
|
|
2023-12-08 18:14:11 +08:00
|
|
|
|
|
|
|
def run(bert_model: str):
|
|
|
|
batch = [
|
|
|
|
"This is an example of BERT working on MLX.",
|
|
|
|
"A second string",
|
|
|
|
"This is another string.",
|
|
|
|
]
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(bert_model)
|
|
|
|
torch_model = AutoModel.from_pretrained(bert_model)
|
|
|
|
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
|
|
|
|
torch_forward = torch_model(**torch_tokens)
|
|
|
|
torch_output = torch_forward.last_hidden_state.detach().numpy()
|
|
|
|
torch_pooled = torch_forward.pooler_output.detach().numpy()
|
|
|
|
|
|
|
|
print("\n HF BERT:")
|
|
|
|
print(torch_output)
|
|
|
|
print("\n\n HF Pooled:")
|
|
|
|
print(torch_pooled[0, :20])
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
2023-12-10 06:15:25 +08:00
|
|
|
parser = argparse.ArgumentParser(
|
2023-12-13 09:08:04 +08:00
|
|
|
description="Run the BERT model using Hugging Face Transformers."
|
2023-12-10 06:15:25 +08:00
|
|
|
)
|
2023-12-08 18:14:11 +08:00
|
|
|
parser.add_argument(
|
|
|
|
"--bert-model",
|
2023-12-10 06:15:25 +08:00
|
|
|
choices=[
|
|
|
|
"bert-base-uncased",
|
|
|
|
"bert-base-cased",
|
|
|
|
"bert-large-uncased",
|
|
|
|
"bert-large-cased",
|
|
|
|
],
|
2023-12-08 18:14:11 +08:00
|
|
|
default="bert-base-uncased",
|
|
|
|
help="The huggingface name of the BERT model to save.",
|
|
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
run(args.bert_model)
|