diff --git a/bert/README.md b/bert/README.md
index bb856ed3..29628cba 100644
--- a/bert/README.md
+++ b/bert/README.md
@@ -12,7 +12,32 @@ python convert.py \
     --mlx-model weights/bert-base-uncased.npz
 ```
 
-## Run the Model
+## Usage
+
+To use the `Bert` model in your own code, you can load it with:
+
+```python
+import mlx.core as mx
+from model import Bert, load_model
+
+model, tokenizer = load_model(
+    "bert-base-uncased",
+    "weights/bert-base-uncased.npz")
+
+batch = ["This is an example of BERT working on MLX."]
+tokens = tokenizer(batch, return_tensors="np", padding=True)
+tokens = {key: mx.array(v) for key, v in tokens.items()}
+
+output, pooled = model(**tokens)
+```
+
+The `output` contains a `Batch x Tokens x Dims` tensor, with one vector per input token.
+If you want to train anything at the **token level**, you'll want to use this.
+
+The `pooled` output contains a `Batch x Dims` tensor, the pooled representation of each input.
+If you want to train a **classification** model, you'll want to use this.
+
+## Comparison with 🤗 `transformers` Implementation
 
 In order to run the model, and have it forward inference on a batch of examples:
 
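
Reviewer note: since the new section tells readers to use `output` for token-level training, a minimal sketch of what that could look like may help. It assumes the `output` variable from the example above and MLX's `mlx.nn.Linear`; the head and `num_tags` are hypothetical and not part of this patch.

```python
import mlx.nn as nn

# Hypothetical token-tagging head (e.g. for NER) on top of `output` from
# the README example. 768 is the hidden size of bert-base-uncased;
# `num_tags` is an assumed, illustrative value.
num_tags = 9
tag_head = nn.Linear(768, num_tags)

# `output` has shape Batch x Tokens x Dims; Linear acts on the last axis,
# so the head yields one tag distribution per token.
tag_logits = tag_head(output)  # Batch x Tokens x num_tags
```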
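
Similarly, a sketch of the classification use of `pooled`, with the same caveats (`num_classes` and the head are illustrative assumptions, not part of the patch):

```python
import mlx.core as mx
import mlx.nn as nn

# Hypothetical sequence-classification head on top of `pooled` from the
# README example. 768 is the hidden size of bert-base-uncased;
# `num_classes` is an assumed, illustrative value.
num_classes = 2
classifier = nn.Linear(768, num_classes)

logits = classifier(pooled)          # Batch x num_classes
probs = mx.softmax(logits, axis=-1)  # class probabilities per example
```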