mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00

Some fixes / cleanup for BERT example (#269)

* some fixes/cleaning for bert + test
* nit
@@ -2,9 +2,15 @@
 
 An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.
 
-## Downloading and Converting Weights
+## Setup
 
-The `convert.py` script relies on `transformers` to download the weights, and exports them as a single `.npz` file.
+Install the requirements:
+
+```
+pip install -r requirements.txt
+```
+
+Then convert the weights with:
 
 ```
 python convert.py \
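For context, the `convert.py` step described above amounts to downloading the Hugging Face checkpoint and re-serializing its tensors into one `.npz` archive. A minimal sketch of that flow; the `convert` function below is hypothetical, and the real script may also remap parameter names to match the MLX module tree:

```python
# Hypothetical sketch of a HF-checkpoint -> .npz conversion;
# not the actual convert.py from this example.
import numpy
from transformers import BertModel

def convert(bert_model: str = "bert-base-uncased",
            mlx_model: str = "weights/bert-base-uncased.npz") -> None:
    model = BertModel.from_pretrained(bert_model)
    # Export every tensor in the PyTorch state dict as a NumPy array.
    tensors = {k: v.numpy() for k, v in model.state_dict().items()}
    # Write a single .npz archive holding all of the weights.
    numpy.savez(mlx_model, **tensors)

if __name__ == "__main__":
    convert()
```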
@@ -30,49 +36,20 @@ tokens = {key: mx.array(v) for key, v in tokens.items()}
 
 output, pooled = model(**tokens)
 ```
 
-The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector for every input token.
-If you want to train anything at a **token-level**, you'll want to use this.
+The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector
+for every input token. If you want to train anything at a **token-level**,
+you'll want to use this.
 
-The `pooled` contains a `Batch x Dims` tensor, which is the pooled representation for each input.
-If you want to train a **classification** model, you'll want to use this.
+The `pooled` contains a `Batch x Dims` tensor, which is the pooled
+representation for each input. If you want to train a **classification**
+model, you'll want to use this.
 
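To make the token-level versus classification distinction concrete, here is a minimal MLX sketch of a classifier head over `pooled`; `BertClassifier`, its `dims` default (768 for `bert-base-uncased`), and `num_classes` are illustrative names, not part of this example:

```python
import mlx.core as mx
import mlx.nn as nn

class BertClassifier(nn.Module):
    """Sketch: a linear classification head over BERT's pooled output."""

    def __init__(self, bert: nn.Module, dims: int = 768, num_classes: int = 2):
        super().__init__()
        self.bert = bert                      # e.g. the Bert module from model.py
        self.head = nn.Linear(dims, num_classes)

    def __call__(self, **tokens) -> mx.array:
        output, pooled = self.bert(**tokens)  # output: B x T x D, pooled: B x D
        return self.head(pooled)              # logits: B x num_classes
```

A token-level model would instead apply its head to `output`, keeping one prediction per input token.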
-## Comparison with 🤗 `transformers` Implementation
-
-In order to run the model, and have it forward inference on a batch of examples:
+## Test
+
+You can check the output for the default model (`bert-base-uncased`) matches the
+Hugging Face version with:
 
-```sh
-python model.py \
-    --bert-model bert-base-uncased \
-    --mlx-model weights/bert-base-uncased.npz
-```
-
-Which will show the following outputs:
-```
-MLX BERT:
-[[[-0.52508914 -0.1993871  -0.28210318 ... -0.61125606  0.19114694
-    0.8227601 ]
-  [-0.8783862  -0.37107834 -0.52238125 ... -0.5067165   1.0847603
-    0.31066895]
-  [-0.70010054 -0.5424497  -0.26593682 ... -0.2688697   0.38338926
-    0.6557663 ]
-  ...
-```
-
-They can be compared against the 🤗 implementation with:
-
-```sh
-python hf_model.py \
-    --bert-model bert-base-uncased
-```
-
-Which will show:
-```
-HF BERT:
-[[[-0.52508944 -0.1993877  -0.28210333 ... -0.6112575   0.19114678
-    0.8227603 ]
-  [-0.878387   -0.371079   -0.522381   ... -0.50671494  1.0847601
-    0.31066933]
-  [-0.7001008  -0.5424504  -0.26593733 ... -0.26887015  0.38339025
-    0.65576553]
-  ...
-```
+```
+python test.py
+```
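The added `test.py` (not shown in this diff) replaces the eyeballed listings above with an automated check. A hedged sketch of that kind of comparison, assuming the `load_model` helper from this example's `model.py` and a tolerance-based match:

```python
# Hypothetical sketch of an MLX-vs-Hugging Face consistency check;
# the real test.py is not shown in this diff.
import mlx.core as mx
import numpy as np
from transformers import AutoTokenizer, BertModel

from model import load_model  # helper assumed from this example's model.py

batch = ["This is an example of BERT working on MLX."]

# MLX forward pass on the converted weights.
mlx_model, tokenizer = load_model(
    "bert-base-uncased", "weights/bert-base-uncased.npz")
tokens = tokenizer(batch, return_tensors="np", padding=True)
mlx_output, _ = mlx_model(**{k: mx.array(v) for k, v in tokens.items()})

# Reference Hugging Face forward pass on the same batch.
hf_model = BertModel.from_pretrained("bert-base-uncased")
hf_tokens = AutoTokenizer.from_pretrained("bert-base-uncased")(
    batch, return_tensors="pt", padding=True)
hf_output = hf_model(**hf_tokens).last_hidden_state.detach().numpy()

# The listings above agree only to a few decimals, so compare with a
# tolerance rather than exact equality.
assert np.allclose(np.array(mlx_output), hf_output, atol=1e-4)
print("MLX and Hugging Face outputs match within tolerance.")
```

Since both forward passes run in float32, an exact equality check would fail; `np.allclose` with a small `atol` is the usual choice.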