From 5ab987b43158fc09f757e8dde5a9ea4dd348c371 Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Fri, 8 Dec 2023 05:14:11 -0500 Subject: [PATCH 1/6] BERT implementation --- bert/README.md | 68 +++++++++++ bert/convert.py | 48 ++++++++ bert/hf_model.py | 36 ++++++ bert/model.py | 246 ++++++++++++++++++++++++++++++++++++++++ bert/weights/.gitignore | 1 + 5 files changed, 399 insertions(+) create mode 100644 bert/README.md create mode 100644 bert/convert.py create mode 100644 bert/hf_model.py create mode 100644 bert/model.py create mode 100644 bert/weights/.gitignore diff --git a/bert/README.md b/bert/README.md new file mode 100644 index 00000000..767e53d7 --- /dev/null +++ b/bert/README.md @@ -0,0 +1,68 @@ +# mlxbert + +A BERT implementation in Apple's new MLX framework. + +## Dependency Installation + +```sh +poetry install --no-root +``` + +If you don't want to do that, simply make sure you have the following dependencies installed: + +- `mlx` +- `transformers` +- `numpy` + +## Download and Convert + +``` +python convert.py \ + --bert-model bert-base-uncased + --mlx-model weights/bert-base-uncased.npz +``` + +## Run the Model + +Right now, this is just a test to show tha the outputs from mlx and huggingface don't change all that much. + +```sh +python model.py \ + --bert-model bert-base-uncased \ + --mlx-model weights/bert-base-uncased.npz +``` + +Which will show the following outputs: +``` +MLX BERT: +[[[-0.17057164 0.08602728 -0.12471077 ... -0.09469379 -0.00275938 + 0.28314582] + [ 0.15222196 -0.48997563 -0.26665813 ... -0.19935863 -0.17162783 + -0.51360303] + [ 0.9460105 0.1358298 -0.2945672 ... 0.00868467 -0.90271163 + -0.2785422 ]]] +``` + +They can be compared against the 🤗 implementation with: + +```sh +python hf_model.py \ + --bert-model bert-base-uncased +``` + +Which will show: +``` + HF BERT: +[[[-0.17057131 0.08602707 -0.12471108 ... -0.09469365 -0.00275959 + 0.28314728] + [ 0.15222463 -0.48997375 -0.26665992 ... -0.19936043 -0.17162988 + -0.5136028 ] + [ 0.946011 0.13582966 -0.29456618 ... 0.00868565 -0.90271175 + -0.27854213]]] +``` + +## To do's + +- [x] fix position encodings +- [x] bert large and cased variants loaded +- [x] example usage \ No newline at end of file diff --git a/bert/convert.py b/bert/convert.py new file mode 100644 index 00000000..d2b7b624 --- /dev/null +++ b/bert/convert.py @@ -0,0 +1,48 @@ +from transformers import BertModel + +import argparse +import numpy + + +def replace_key(key: str) -> str: + key = key.replace(".layer.", ".layers.") + key = key.replace(".self.key.", ".key_proj.") + key = key.replace(".self.query.", ".query_proj.") + key = key.replace(".self.value.", ".value_proj.") + key = key.replace(".attention.output.dense.", ".attention.out_proj.") + key = key.replace(".attention.output.LayerNorm.", ".ln1.") + key = key.replace(".output.LayerNorm.", ".ln2.") + key = key.replace(".intermediate.dense.", ".linear1.") + key = key.replace(".output.dense.", ".linear2.") + key = key.replace(".LayerNorm.", ".norm.") + key = key.replace("pooler.dense.", "pooler.") + return key + + +def convert(bert_model: str, mlx_model: str) -> None: + model = BertModel.from_pretrained(bert_model) + # save the tensors + tensors = { + replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items() + } + numpy.savez(mlx_model, **tensors) + # save the tokenizer + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.") + parser.add_argument( + "--bert-model", + type=str, + default="bert-base-uncased", + help="The huggingface name of the BERT model to save.", + ) + parser.add_argument( + "--mlx-model", + type=str, + default="weights/bert-base-uncased.npz", + help="The output path for the MLX BERT weights.", + ) + args = parser.parse_args() + + convert(args.bert_model, args.mlx_model) \ No newline at end of file diff --git a/bert/hf_model.py b/bert/hf_model.py new file mode 100644 index 00000000..13350e4a --- /dev/null +++ b/bert/hf_model.py @@ -0,0 +1,36 @@ +from transformers import AutoModel, AutoTokenizer + +import argparse + + +def run(bert_model: str): + batch = [ + "This is an example of BERT working on MLX.", + "A second string", + "This is another string.", + ] + + tokenizer = AutoTokenizer.from_pretrained(bert_model) + torch_model = AutoModel.from_pretrained(bert_model) + torch_tokens = tokenizer(batch, return_tensors="pt", padding=True) + torch_forward = torch_model(**torch_tokens) + torch_output = torch_forward.last_hidden_state.detach().numpy() + torch_pooled = torch_forward.pooler_output.detach().numpy() + + print("\n HF BERT:") + print(torch_output) + print("\n\n HF Pooled:") + print(torch_pooled[0, :20]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--bert-model", + type=str, + default="bert-base-uncased", + help="The huggingface name of the BERT model to save.", + ) + args = parser.parse_args() + + run(args.bert_model) diff --git a/bert/model.py b/bert/model.py new file mode 100644 index 00000000..01ac294b --- /dev/null +++ b/bert/model.py @@ -0,0 +1,246 @@ +from typing import Optional +from dataclasses import dataclass +from mlx.utils import tree_unflatten, tree_map +from mlx.nn.layers.base import Module +from mlx.nn.layers.linear import Linear +from mlx.nn.layers.normalization import LayerNorm +from transformers import AutoTokenizer + +import mlx.core as mx +import mlx.nn as nn +import argparse +import numpy +import math + + +@dataclass +class ModelArgs: + intermediate_size: int = 768 + num_attention_heads: int = 12 + num_hidden_layers: int = 12 + vocab_size: int = 30522 + attention_probs_dropout_prob: float = 0.1 + hidden_dropout_prob: float = 0.1 + layer_norm_eps: float = 1e-12 + max_position_embeddings: int = 512 + + +model_configs = { + "bert-base-uncased": ModelArgs(), + "bert-base-cased": ModelArgs(), + "bert-large-uncased": ModelArgs( + intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24 + ), + "bert-large-cased": ModelArgs( + intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24 + ), +} + + +class MultiHeadAttention(Module): + """ + Minor update to the MultiHeadAttention module to ensure that the + projections use bias. + """ + + def __init__( + self, + dims: int, + num_heads: int, + query_input_dims: Optional[int] = None, + key_input_dims: Optional[int] = None, + value_input_dims: Optional[int] = None, + value_dims: Optional[int] = None, + value_output_dims: Optional[int] = None, + ): + super().__init__() + + if (dims % num_heads) != 0: + raise ValueError( + f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0" + ) + + query_input_dims = query_input_dims or dims + key_input_dims = key_input_dims or dims + value_input_dims = value_input_dims or key_input_dims + value_dims = value_dims or dims + value_output_dims = value_output_dims or dims + + self.num_heads = num_heads + self.query_proj = Linear(query_input_dims, dims, True) + self.key_proj = Linear(key_input_dims, dims, True) + self.value_proj = Linear(value_input_dims, value_dims, True) + self.out_proj = Linear(value_dims, value_output_dims, True) + + def __call__(self, queries, keys, values, mask=None): + queries = self.query_proj(queries) + keys = self.key_proj(keys) + values = self.value_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + # Dimensions are [batch x num heads x sequence x hidden dim] + scale = math.sqrt(1 / queries.shape[-1]) + scores = (queries * scale) @ keys + if mask is not None: + mask = self.converrt_mask_to_additive_causal_mask(mask) + mask = mx.expand_dims(mask, (1, 2)) + mask = mx.broadcast_to(mask, scores.shape) + scores = scores + mask.astype(scores.dtype) + scores = mx.softmax(scores, axis=-1) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(values_hat) + + def converrt_mask_to_additive_causal_mask( + self, mask: mx.array, dtype: mx.Dtype = mx.float32 + ) -> mx.array: + mask = mask == 0 + mask = mask.astype(dtype) * -1e9 + return mask + + +class TransformerEncoderLayer(Module): + """ + A transformer encoder layer with (the original BERT) post-normalization. + """ + + def __init__( + self, + dims: int, + num_heads: int, + mlp_dims: Optional[int] = None, + layer_norm_eps: float = 1e-12, + ): + super().__init__() + mlp_dims = mlp_dims or dims * 4 + self.attention = MultiHeadAttention(dims, num_heads) + self.ln1 = LayerNorm(dims, eps=layer_norm_eps) + self.ln2 = LayerNorm(dims, eps=layer_norm_eps) + self.linear1 = Linear(dims, mlp_dims) + self.linear2 = Linear(mlp_dims, dims) + self.gelu = nn.GELU() + + def __call__(self, x, mask): + attention_out = self.attention(x, x, x, mask) + add_and_norm = self.ln1(x + attention_out) + + ff = self.linear1(add_and_norm) + ff_gelu = self.gelu(ff) + ff_out = self.linear2(ff_gelu) + x = self.ln2(ff_out + add_and_norm) + + return x + + +class TransformerEncoder(Module): + def __init__( + self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None + ): + super().__init__() + self.layers = [ + TransformerEncoderLayer(dims, num_heads, mlp_dims) + for i in range(num_layers) + ] + + def __call__(self, x, mask): + for l in self.layers: + x = l(x, mask) + + return x + + +class BertEmbeddings(nn.Module): + def __init__(self, config: ModelArgs): + self.word_embeddings = nn.Embedding(config.vocab_size, config.intermediate_size) + self.token_type_embeddings = nn.Embedding(2, config.intermediate_size) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.intermediate_size + ) + self.norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps) + + def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array: + words = self.word_embeddings(input_ids) + position = self.position_embeddings( + mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape) + ) + token_types = self.token_type_embeddings(token_type_ids) + + embeddings = position + words + token_types + return self.norm(embeddings) + + +class Bert(nn.Module): + def __init__(self, config: ModelArgs): + self.embeddings = BertEmbeddings(config) + self.encoder = TransformerEncoder( + num_layers=config.num_hidden_layers, + dims=config.intermediate_size, + num_heads=config.num_attention_heads, + ) + self.pooler = nn.Linear(config.intermediate_size, config.vocab_size) + + def __call__( + self, + input_ids: mx.array, + token_type_ids: mx.array, + attention_mask: mx.array | None = None, + ) -> tuple[mx.array, mx.array]: + x = self.embeddings(input_ids, token_type_ids) + y = self.encoder(x, attention_mask) + return y, mx.tanh(self.pooler(y[:, 0])) + + +def run(bert_model: str, mlx_model: str): + batch = [ + "This is an example of BERT working on MLX.", + "A second string", + "This is another string.", + ] + + model = Bert(model_configs[bert_model]) + + weights = mx.load(mlx_model) + weights = tree_unflatten(list(weights.items())) + weights = tree_map(lambda p: mx.array(p), weights) + + model.update(weights) + + tokenizer = AutoTokenizer.from_pretrained(bert_model) + + tokens = tokenizer(batch, return_tensors="np", padding=True) + tokens = {key: mx.array(v) for key, v in tokens.items()} + + mlx_output, mlx_pooled = model(**tokens) + mlx_output = numpy.array(mlx_output) + mlx_pooled = numpy.array(mlx_pooled) + + print("MLX BERT:") + print(mlx_output) + + print("\n\nMLX Pooled:") + print(mlx_pooled[0, :20]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.") + parser.add_argument( + "--bert-model", + type=str, + default="bert-base-uncased", + help="The huggingface name of the BERT model to save.", + ) + parser.add_argument( + "--mlx-model", + type=str, + default="weights/bert-base-uncased.npz", + help="The output path for the MLX BERT weights.", + ) + args = parser.parse_args() + + run(args.bert_model, args.mlx_model) \ No newline at end of file diff --git a/bert/weights/.gitignore b/bert/weights/.gitignore new file mode 100644 index 00000000..44662642 --- /dev/null +++ b/bert/weights/.gitignore @@ -0,0 +1 @@ +*.npz \ No newline at end of file From 7aa8348f6081817d5a86fc191c79a801bf8f03f9 Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Fri, 8 Dec 2023 10:20:50 -0500 Subject: [PATCH 2/6] Update README for mlx-examples repo --- bert/README.md | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/bert/README.md b/bert/README.md index 767e53d7..e1b7a433 100644 --- a/bert/README.md +++ b/bert/README.md @@ -1,20 +1,10 @@ # mlxbert -A BERT implementation in Apple's new MLX framework. +An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within mlx. -## Dependency Installation +## Downloading and Converting Weights -```sh -poetry install --no-root -``` - -If you don't want to do that, simply make sure you have the following dependencies installed: - -- `mlx` -- `transformers` -- `numpy` - -## Download and Convert +The `convert.py` script relies on `transformers` to download the weights, and exports them as a single `.npz` file. ``` python convert.py \ @@ -24,7 +14,7 @@ python convert.py \ ## Run the Model -Right now, this is just a test to show tha the outputs from mlx and huggingface don't change all that much. +In order to run the model, and have it forward inference on a batch of examples: ```sh python model.py \ @@ -60,9 +50,3 @@ Which will show: [ 0.946011 0.13582966 -0.29456618 ... 0.00868565 -0.90271175 -0.27854213]]] ``` - -## To do's - -- [x] fix position encodings -- [x] bert large and cased variants loaded -- [x] example usage \ No newline at end of file From d6398681d1e72b408ff912f62057243238fd9d6b Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Sat, 9 Dec 2023 10:41:15 -0500 Subject: [PATCH 3/6] Cleaning implementation for merge --- bert/README.md | 4 +-- bert/convert.py | 3 +-- bert/hf_model.py | 4 +-- bert/model.py | 66 +++++++++++++++++++++++++----------------------- 4 files changed, 39 insertions(+), 38 deletions(-) diff --git a/bert/README.md b/bert/README.md index e1b7a433..bb856ed3 100644 --- a/bert/README.md +++ b/bert/README.md @@ -1,6 +1,6 @@ -# mlxbert +# BERT -An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within mlx. +An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX. ## Downloading and Converting Weights diff --git a/bert/convert.py b/bert/convert.py index d2b7b624..5a9298d6 100644 --- a/bert/convert.py +++ b/bert/convert.py @@ -26,14 +26,13 @@ def convert(bert_model: str, mlx_model: str) -> None: replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items() } numpy.savez(mlx_model, **tensors) - # save the tokenizer if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.") parser.add_argument( "--bert-model", - type=str, + choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"], default="bert-base-uncased", help="The huggingface name of the BERT model to save.", ) diff --git a/bert/hf_model.py b/bert/hf_model.py index 13350e4a..9f73028d 100644 --- a/bert/hf_model.py +++ b/bert/hf_model.py @@ -24,10 +24,10 @@ def run(bert_model: str): if __name__ == "__main__": - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(description="Run the BERT model using HuggingFace Transformers.") parser.add_argument( "--bert-model", - type=str, + choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"], default="bert-base-uncased", help="The huggingface name of the BERT model to save.", ) diff --git a/bert/model.py b/bert/model.py index 01ac294b..318f52ce 100644 --- a/bert/model.py +++ b/bert/model.py @@ -1,10 +1,7 @@ from typing import Optional from dataclasses import dataclass -from mlx.utils import tree_unflatten, tree_map -from mlx.nn.layers.base import Module -from mlx.nn.layers.linear import Linear -from mlx.nn.layers.normalization import LayerNorm -from transformers import AutoTokenizer +from transformers import BertTokenizer +from mlx.utils import tree_unflatten import mlx.core as mx import mlx.nn as nn @@ -37,7 +34,7 @@ model_configs = { } -class MultiHeadAttention(Module): +class MultiHeadAttention(nn.Module): """ Minor update to the MultiHeadAttention module to ensure that the projections use bias. @@ -67,10 +64,10 @@ class MultiHeadAttention(Module): value_output_dims = value_output_dims or dims self.num_heads = num_heads - self.query_proj = Linear(query_input_dims, dims, True) - self.key_proj = Linear(key_input_dims, dims, True) - self.value_proj = Linear(value_input_dims, value_dims, True) - self.out_proj = Linear(value_dims, value_output_dims, True) + self.query_proj = nn.Linear(query_input_dims, dims, True) + self.key_proj = nn.Linear(key_input_dims, dims, True) + self.value_proj = nn.Linear(value_input_dims, value_dims, True) + self.out_proj = nn.Linear(value_dims, value_output_dims, True) def __call__(self, queries, keys, values, mask=None): queries = self.query_proj(queries) @@ -105,7 +102,7 @@ class MultiHeadAttention(Module): return mask -class TransformerEncoderLayer(Module): +class TransformerEncoderLayer(nn.Module): """ A transformer encoder layer with (the original BERT) post-normalization. """ @@ -120,10 +117,10 @@ class TransformerEncoderLayer(Module): super().__init__() mlp_dims = mlp_dims or dims * 4 self.attention = MultiHeadAttention(dims, num_heads) - self.ln1 = LayerNorm(dims, eps=layer_norm_eps) - self.ln2 = LayerNorm(dims, eps=layer_norm_eps) - self.linear1 = Linear(dims, mlp_dims) - self.linear2 = Linear(mlp_dims, dims) + self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps) + self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps) + self.linear1 = nn.Linear(dims, mlp_dims) + self.linear2 = nn.Linear(mlp_dims, dims) self.gelu = nn.GELU() def __call__(self, x, mask): @@ -138,7 +135,7 @@ class TransformerEncoderLayer(Module): return x -class TransformerEncoder(Module): +class TransformerEncoder(nn.Module): def __init__( self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None ): @@ -149,8 +146,8 @@ class TransformerEncoder(Module): ] def __call__(self, x, mask): - for l in self.layers: - x = l(x, mask) + for layer in self.layers: + x = layer(x, mask) return x @@ -196,23 +193,28 @@ class Bert(nn.Module): return y, mx.tanh(self.pooler(y[:, 0])) +def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]: + # load the weights npz + weights = mx.load(weights_path) + weights = tree_unflatten(list(weights.items())) + # create and update the model + model = Bert(model_configs[bert_model]) + model.update(weights) + + tokenizer = BertTokenizer.from_pretrained(bert_model) + + return model, tokenizer + + def run(bert_model: str, mlx_model: str): + model, tokenizer = load_model(bert_model, mlx_model) + batch = [ "This is an example of BERT working on MLX.", "A second string", "This is another string.", ] - - model = Bert(model_configs[bert_model]) - - weights = mx.load(mlx_model) - weights = tree_unflatten(list(weights.items())) - weights = tree_map(lambda p: mx.array(p), weights) - - model.update(weights) - - tokenizer = AutoTokenizer.from_pretrained(bert_model) - + tokens = tokenizer(batch, return_tensors="np", padding=True) tokens = {key: mx.array(v) for key, v in tokens.items()} @@ -228,7 +230,7 @@ def run(bert_model: str, mlx_model: str): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.") + parser = argparse.ArgumentParser(description="Run the BERT model using MLX.") parser.add_argument( "--bert-model", type=str, @@ -239,8 +241,8 @@ if __name__ == "__main__": "--mlx-model", type=str, default="weights/bert-base-uncased.npz", - help="The output path for the MLX BERT weights.", + help="The path of the stored MLX BERT weights (npz file).", ) args = parser.parse_args() - run(args.bert_model, args.mlx_model) \ No newline at end of file + run(args.bert_model, args.mlx_model) From 04350eb0a6a37ca715b0f52935dffab525683888 Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Sat, 9 Dec 2023 10:48:34 -0500 Subject: [PATCH 4/6] Updating README --- bert/README.md | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/bert/README.md b/bert/README.md index bb856ed3..29628cba 100644 --- a/bert/README.md +++ b/bert/README.md @@ -12,7 +12,31 @@ python convert.py \ --mlx-model weights/bert-base-uncased.npz ``` -## Run the Model +## Usage + +To use the `Bert` model in your own code, you can load it with: + +```python +from model import Bert, load_model + +model, tokenizer = load_model( + "bert-base-uncased", + "weights/bert-base-uncased.npz") + +batch = ["This is an example of BERT working on MLX."] +tokens = tokenizer(batch, return_tensors="np", padding=True) +tokens = {key: mx.array(v) for key, v in tokens.items()} + +output, pooled = model(**tokens) +``` + +The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector for every input token. +If you want to train anything at a **token-level**, you'll want to use this. + +The `pooled` contains a `Batch x Dims` tensor, which is the pooled representation for each input. +If you want to train a **classification** model, you'll want to use this. + +## Comparison with 🤗 `transformers` Implementation In order to run the model, and have it forward inference on a batch of examples: From 187798967cb32a610b921cac0fb28caf6b4b14a3 Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Sat, 9 Dec 2023 10:52:55 -0500 Subject: [PATCH 5/6] Requirements for running BERT --- bert/requirements.txt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 bert/requirements.txt diff --git a/bert/requirements.txt b/bert/requirements.txt new file mode 100644 index 00000000..24266334 --- /dev/null +++ b/bert/requirements.txt @@ -0,0 +1,3 @@ +mlx +transformers +numpy \ No newline at end of file From c5733b48fdcbb9d7851ca61bec8860223242d64f Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Sat, 9 Dec 2023 12:01:58 -0500 Subject: [PATCH 6/6] Updating README for current example, making python>=3.8 compatibile, and fixing code type --- bert/README.md | 26 ++++++++++++++------------ bert/model.py | 6 +++--- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/bert/README.md b/bert/README.md index 29628cba..cea738df 100644 --- a/bert/README.md +++ b/bert/README.md @@ -49,12 +49,13 @@ python model.py \ Which will show the following outputs: ``` MLX BERT: -[[[-0.17057164 0.08602728 -0.12471077 ... -0.09469379 -0.00275938 - 0.28314582] - [ 0.15222196 -0.48997563 -0.26665813 ... -0.19935863 -0.17162783 - -0.51360303] - [ 0.9460105 0.1358298 -0.2945672 ... 0.00868467 -0.90271163 - -0.2785422 ]]] +[[[-0.52508914 -0.1993871 -0.28210318 ... -0.61125606 0.19114694 + 0.8227601 ] + [-0.8783862 -0.37107834 -0.52238125 ... -0.5067165 1.0847603 + 0.31066895] + [-0.70010054 -0.5424497 -0.26593682 ... -0.2688697 0.38338926 + 0.6557663 ] + ... ``` They can be compared against the 🤗 implementation with: @@ -67,10 +68,11 @@ python hf_model.py \ Which will show: ``` HF BERT: -[[[-0.17057131 0.08602707 -0.12471108 ... -0.09469365 -0.00275959 - 0.28314728] - [ 0.15222463 -0.48997375 -0.26665992 ... -0.19936043 -0.17162988 - -0.5136028 ] - [ 0.946011 0.13582966 -0.29456618 ... 0.00868565 -0.90271175 - -0.27854213]]] +[[[-0.52508944 -0.1993877 -0.28210333 ... -0.6112575 0.19114678 + 0.8227603 ] + [-0.878387 -0.371079 -0.522381 ... -0.50671494 1.0847601 + 0.31066933] + [-0.7001008 -0.5424504 -0.26593733 ... -0.26887015 0.38339025 + 0.65576553] + ... ``` diff --git a/bert/model.py b/bert/model.py index 318f52ce..446919b1 100644 --- a/bert/model.py +++ b/bert/model.py @@ -85,7 +85,7 @@ class MultiHeadAttention(nn.Module): scale = math.sqrt(1 / queries.shape[-1]) scores = (queries * scale) @ keys if mask is not None: - mask = self.converrt_mask_to_additive_causal_mask(mask) + mask = self.convert_mask_to_additive_causal_mask(mask) mask = mx.expand_dims(mask, (1, 2)) mask = mx.broadcast_to(mask, scores.shape) scores = scores + mask.astype(scores.dtype) @@ -94,7 +94,7 @@ class MultiHeadAttention(nn.Module): return self.out_proj(values_hat) - def converrt_mask_to_additive_causal_mask( + def convert_mask_to_additive_causal_mask( self, mask: mx.array, dtype: mx.Dtype = mx.float32 ) -> mx.array: mask = mask == 0 @@ -186,7 +186,7 @@ class Bert(nn.Module): self, input_ids: mx.array, token_type_ids: mx.array, - attention_mask: mx.array | None = None, + attention_mask: Optional[mx.array] = None, ) -> tuple[mx.array, mx.array]: x = self.embeddings(input_ids, token_type_ids) y = self.encoder(x, attention_mask)