mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
35 lines
1.2 KiB
Python
35 lines
1.2 KiB
Python
![]() |
from typing import List
|
||
|
|
||
|
import model
|
||
|
import numpy as np
|
||
|
from transformers import AutoModel, AutoTokenizer
|
||
|
|
||
|
|
||
|
def run_torch(bert_model: str, batch: List[str]):
    """Run the Hugging Face reference model over ``batch``.

    Args:
        bert_model: Identifier handed to ``AutoTokenizer.from_pretrained``
            and ``AutoModel.from_pretrained``.
        batch: Strings tokenized together with padding enabled and
            returned as PyTorch tensors.

    Returns:
        A ``(last_hidden_state, pooler_output)`` tuple taken from the
        model's forward result, each detached and converted to a NumPy
        array.
    """
    encode = AutoTokenizer.from_pretrained(bert_model)
    reference = AutoModel.from_pretrained(bert_model)

    result = reference(**encode(batch, return_tensors="pt", padding=True))

    # Detach from the autograd graph before handing the data to NumPy.
    hidden, pooled = (
        tensor.detach().numpy()
        for tensor in (result.last_hidden_state, result.pooler_output)
    )
    return hidden, pooled
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Feed the same batch through the PyTorch reference (`run_torch`) and
    # through `model.run`, then require both pairs of outputs to agree
    # within the tolerances below.
    bert_model = "bert-base-uncased"
    mlx_model = "weights/bert-base-uncased.npz"
    batch = [
        "This is an example of BERT working in MLX.",
        "A second string",
        "This is another string.",
    ]

    torch_output, torch_pooled = run_torch(bert_model, batch)
    mlx_output, mlx_pooled = model.run(bert_model, mlx_model, batch)

    # Checked in order: the first mismatch aborts with its own message.
    for reference, candidate, message in (
        (torch_output, mlx_output, "Model output is different"),
        (torch_pooled, mlx_pooled, "Model pooled output is different"),
    ):
        assert np.allclose(reference, candidate, rtol=1e-4, atol=1e-5), message

    print("Tests pass :)")
|