# BERT

An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) in MLX.

## Setup

Install the requirements:
```
pip install -r requirements.txt
```
Then convert the weights with:

```
python convert.py \
    --bert-model bert-base-uncased \
    --mlx-model weights/bert-base-uncased.npz
```

## Usage

To use the `Bert` model in your own code, you can load it with:
```python
import mlx.core as mx
from model import Bert, load_model

model, tokenizer = load_model(
    "bert-base-uncased",
    "weights/bert-base-uncased.npz")
# Tokenize a batch of sentences and convert the NumPy arrays to MLX arrays
batch = ["This is an example of BERT working on MLX."]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}

# Run the model: per-token features and the pooled sentence representation
output, pooled = model(**tokens)
```
The `output` is a `Batch x Tokens x Dims` tensor containing a vector for every
input token. Use this if you want to train anything at the **token level**.

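For example, to inspect the per-token vectors (a minimal sketch, not part of
this repo, assuming the default `bert-base-uncased` hidden size of 768):

```python
# Per-token features: one 768-dimensional vector per input token.
print(output.shape)        # (batch, num_tokens, 768)
cls_vector = output[0, 0]  # vector for the first ([CLS]) token of the first example
token_vectors = output[0]  # all token vectors for the first example
```
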
The `pooled` output is a `Batch x Dims` tensor, the pooled representation for
each input. Use this if you want to train a **classification** model.

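For instance, a hypothetical classification head could sit on top of the pooled
output (a minimal sketch, again assuming a hidden size of 768; `num_labels` is
a placeholder for your task):

```python
import mlx.nn as nn

num_labels = 2                           # hypothetical number of classes
classifier = nn.Linear(768, num_labels)  # 768 = bert-base-uncased hidden size

logits = classifier(pooled)              # (batch, num_labels)
```
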
## Test

You can check that the output of the default model (`bert-base-uncased`)
matches the Hugging Face version with:

```
python test.py
```