BERT

An implementation of BERT (Devlin et al., 2019) in MLX.

Setup

Install the requirements:

pip install -r requirements.txt

Then convert the Hugging Face weights with:

python convert.py \
    --bert-model bert-base-uncased \
    --mlx-model weights/bert-base-uncased.npz
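The .npz extension suggests the converted file is a standard NumPy archive that maps parameter names to arrays. This is an assumption about the output of convert.py, not something stated above. Under that assumption, listing the entries and their shapes is a quick sanity check on a converted file. The helper summarize_npz below is hypothetical and is not part of this repository.

```python
import numpy as np


def summarize_npz(path):
    """Return {parameter name: shape} for every array in an .npz archive.

    Hypothetical helper for inspecting converted weights; not part of the repo.
    Assumes the file is a plain NumPy .npz archive.
    """
    # NpzFile supports the context-manager protocol, so the file is closed afterwards.
    with np.load(path) as archive:
        return {name: archive[name].shape for name in archive.files}
```

For example, summarize_npz("weights/bert-base-uncased.npz") would return the name and shape of each stored parameter.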

Usage

To use the BERT model in your own code, load it with:

import mlx.core as mx
from model import load_model

model, tokenizer = load_model(
    "bert-base-uncased",
    "weights/bert-base-uncased.npz")

batch = ["This is an example of BERT working on MLX."]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}

output, pooled = model(**tokens)
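With padding=True, a Hugging Face tokenizer pads the shorter sequences in a batch to a common length and returns an attention_mask that marks real tokens with 1 and padding with 0. To show what the model receives, the sketch below reproduces that padding logic on token ids that have already been tokenized. The helper pad_batch is a hypothetical, simplified stand-in for the tokenizer call above. It assumes right padding, a pad id of 0 (the [PAD] id in the bert-base-uncased vocabulary), and single-segment inputs.

```python
import numpy as np


def pad_batch(sequences, pad_id=0):
    """Right-pad token-id lists to the longest length in the batch.

    Hypothetical, simplified stand-in for tokenizer(..., padding=True).
    Returns int32 arrays of shape (batch, max_len).
    """
    batch = len(sequences)
    max_len = max(len(seq) for seq in sequences)
    input_ids = np.full((batch, max_len), pad_id, dtype=np.int32)
    attention_mask = np.zeros((batch, max_len), dtype=np.int32)
    for i, seq in enumerate(sequences):
        input_ids[i, : len(seq)] = seq
        attention_mask[i, : len(seq)] = 1  # 1 = real token, 0 = padding
    # Single-segment inputs: every position belongs to segment 0.
    token_type_ids = np.zeros((batch, max_len), dtype=np.int32)
    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
    }
```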

The output is a Batch x Tokens x Dims tensor containing one vector for every input token. Use it if you want to train anything at the token level.
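As an illustration of token-level use, the sketch below applies one linear layer to every token vector independently. It maps Batch x Tokens x Dims to Batch x Tokens x Labels, for example per-token tag scores, and it ignores padding positions. The code uses NumPy with hypothetical, untrained weights w and b instead of a trained MLX layer. An MLX array such as output can be converted to NumPy with np.array(output).

```python
import numpy as np


def token_logits(hidden, w, b):
    """Per-token linear head.

    hidden: (batch, tokens, dims), e.g. np.array(output)
    w: (dims, num_labels), b: (num_labels,)  -- hypothetical, untrained
    Returns (batch, tokens, num_labels) scores.
    """
    # The matmul broadcasts over the batch and token axes.
    return hidden @ w + b


def token_predictions(hidden, w, b, attention_mask):
    """Highest-scoring label per token, with padding positions set to -1."""
    labels = token_logits(hidden, w, b).argmax(axis=-1)
    return np.where(attention_mask == 1, labels, -1)
```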

The pooled output is a Batch x Dims tensor containing the pooled representation of each input. Use it if you want to train a classification model.
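For sequence classification, a common starting point is a single linear layer over the pooled vectors, followed by a softmax. In the original BERT design, the pooled vector is a tanh-activated dense projection of the first ([CLS]) token's final hidden state, which is why it summarizes the whole input. The sketch below uses NumPy with hypothetical, untrained weights w and b. Convert the MLX array first with np.array(pooled).

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def classify(pooled, w, b):
    """Class probabilities from pooled BERT features.

    pooled: (batch, dims), e.g. np.array(pooled) from the model
    w: (dims, num_classes), b: (num_classes,)  -- hypothetical, untrained
    Returns (batch, num_classes) probabilities; each row sums to 1.
    """
    return softmax(pooled @ w + b)
```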

Test

To check that the output of the default model (bert-base-uncased) matches the Hugging Face version, run:

python test.py
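A check like this comes down to verifying that both implementations produce numerically close arrays for the same input. The sketch below shows one way to express that comparison with NumPy. The function name outputs_match and the default tolerance are assumptions for illustration only, not what test.py actually uses.

```python
import numpy as np


def outputs_match(mlx_out, hf_out, atol=1e-4):
    """Compare two output arrays elementwise.

    Convert inputs to NumPy-compatible data first, e.g. np.array(x) for MLX
    arrays or tensor.detach().numpy() for PyTorch tensors.
    atol is an assumed tolerance, not the value test.py uses.
    Returns (matches, max_abs_diff).
    """
    a = np.asarray(mlx_out, dtype=np.float32)
    b = np.asarray(hf_out, dtype=np.float32)
    if a.shape != b.shape:
        return False, float("inf")
    max_diff = float(np.abs(a - b).max())
    return max_diff <= atol, max_diff
```

Reporting the maximum absolute difference alongside the pass/fail result makes it easier to tell a small floating-point drift from a real mismatch.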