# BERT

An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.

## Downloading and Converting Weights

The `convert.py` script relies on `transformers` to download the weights and exports them as a single `.npz` file.

```sh
python convert.py \
    --bert-model bert-base-uncased \
    --mlx-model weights/bert-base-uncased.npz
```
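
The exported file is a standard NumPy archive, so you can sanity-check the conversion before loading it into MLX. A minimal sketch, assuming the output path used above:

```python
import numpy as np

# Each entry in the archive is one model parameter, keyed by name.
weights = np.load("weights/bert-base-uncased.npz")
for name in list(weights.keys())[:5]:
    print(name, weights[name].shape)
```
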
## Usage

To use the `Bert` model in your own code, you can load it with:
```python
import mlx.core as mx

from model import Bert, load_model

# Load the converted weights and the matching tokenizer.
model, tokenizer = load_model(
    "bert-base-uncased",
    "weights/bert-base-uncased.npz")

# Tokenize a batch and convert the NumPy arrays to MLX arrays.
batch = ["This is an example of BERT working on MLX."]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}

output, pooled = model(**tokens)
```
`output` is a `Batch x Tokens x Dims` tensor, holding a contextual vector for every input token.
If you want to train anything at the **token level**, you'll want to use this.
`pooled` is a `Batch x Dims` tensor, holding the pooled representation of each input sequence.
If you want to train a **classification** model, you'll want to use this.
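
For example, here is a minimal sketch that treats `pooled` as a sentence embedding and scores two sentences with cosine similarity. It reuses the `load_model` call from above; note that the pre-trained pooler is not fine-tuned for semantic similarity, so this only illustrates the API:

```python
import mlx.core as mx

from model import load_model

model, tokenizer = load_model(
    "bert-base-uncased",
    "weights/bert-base-uncased.npz")

batch = ["The weather is lovely today.", "It is sunny outside."]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}
_, pooled = model(**tokens)

# Normalize each Batch x Dims row to unit length, then take the dot
# product of the two rows to get the cosine similarity.
unit = pooled / mx.sqrt(mx.sum(pooled * pooled, axis=-1, keepdims=True))
print(mx.sum(unit[0] * unit[1]).item())
```
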
## Comparison with 🤗 `transformers` Implementation

To run the model and have it perform inference on a batch of examples:
```sh
python model.py \
    --bert-model bert-base-uncased \
    --mlx-model weights/bert-base-uncased.npz
```
This will show the following output:
```
MLX BERT:
[[[-0.52508914 -0.1993871 -0.28210318 ... -0.61125606 0.19114694
0.8227601 ]
[-0.8783862 -0.37107834 -0.52238125 ... -0.5067165 1.0847603
0.31066895]
[-0.70010054 -0.5424497 -0.26593682 ... -0.2688697 0.38338926
0.6557663 ]
...
```
They can be compared against the 🤗 implementation with:
```sh
python hf_model.py \
    --bert-model bert-base-uncased
```
This will show:
```
HF BERT:
[[[-0.52508944 -0.1993877 -0.28210333 ... -0.6112575 0.19114678
0.8227603 ]
[-0.878387 -0.371079 -0.522381 ... -0.50671494 1.0847601
0.31066933]
[-0.7001008 -0.5424504 -0.26593733 ... -0.26887015 0.38339025
0.65576553]
...
```
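
To check the agreement programmatically, here is a minimal sketch that runs both models on the same batch and compares the token-level outputs with `np.allclose`. It assumes the `load_model` helper shown in the Usage section and the standard 🤗 `transformers` API; the `1e-4` tolerance is a guess based on the float32 differences visible above:

```python
import mlx.core as mx
import numpy as np
from transformers import AutoModel, AutoTokenizer

from model import load_model

batch = ["This is an example of BERT working on MLX."]

# Forward pass through the MLX port.
mlx_model, mlx_tokenizer = load_model(
    "bert-base-uncased",
    "weights/bert-base-uncased.npz")
tokens = mlx_tokenizer(batch, return_tensors="np", padding=True)
mlx_output, _ = mlx_model(**{k: mx.array(v) for k, v in tokens.items()})

# Forward pass through the Hugging Face reference implementation.
hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
hf_model = AutoModel.from_pretrained("bert-base-uncased")
hf_tokens = hf_tokenizer(batch, return_tensors="pt", padding=True)
hf_output = hf_model(**hf_tokens).last_hidden_state.detach().numpy()

# Float32 rounding means the two outputs only match approximately.
print(np.allclose(np.array(mlx_output), hf_output, atol=1e-4))
```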