From d873e10dfe248d33c7db4c3f896e1d07c372af18 Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Sat, 9 Dec 2023 12:01:58 -0500 Subject: [PATCH] Updating README for current example, making python>=3.8 compatible, and fixing code typo --- bert/README.md | 26 ++++++++++++++------------ bert/model.py | 6 +++--- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/bert/README.md b/bert/README.md index 29628cba..cea738df 100644 --- a/bert/README.md +++ b/bert/README.md @@ -49,12 +49,13 @@ python model.py \ Which will show the following outputs: ``` MLX BERT: -[[[-0.17057164 0.08602728 -0.12471077 ... -0.09469379 -0.00275938 - 0.28314582] - [ 0.15222196 -0.48997563 -0.26665813 ... -0.19935863 -0.17162783 - -0.51360303] - [ 0.9460105 0.1358298 -0.2945672 ... 0.00868467 -0.90271163 - -0.2785422 ]]] +[[[-0.52508914 -0.1993871 -0.28210318 ... -0.61125606 0.19114694 + 0.8227601 ] + [-0.8783862 -0.37107834 -0.52238125 ... -0.5067165 1.0847603 + 0.31066895] + [-0.70010054 -0.5424497 -0.26593682 ... -0.2688697 0.38338926 + 0.6557663 ] + ... ``` They can be compared against the 🤗 implementation with: @@ -67,10 +68,11 @@ python hf_model.py \ Which will show: ``` HF BERT: -[[[-0.17057131 0.08602707 -0.12471108 ... -0.09469365 -0.00275959 - 0.28314728] - [ 0.15222463 -0.48997375 -0.26665992 ... -0.19936043 -0.17162988 - -0.5136028 ] - [ 0.946011 0.13582966 -0.29456618 ... 0.00868565 -0.90271175 - -0.27854213]]] +[[[-0.52508944 -0.1993877 -0.28210333 ... -0.6112575 0.19114678 + 0.8227603 ] + [-0.878387 -0.371079 -0.522381 ... -0.50671494 1.0847601 + 0.31066933] + [-0.7001008 -0.5424504 -0.26593733 ... -0.26887015 0.38339025 + 0.65576553] + ... 
``` diff --git a/bert/model.py b/bert/model.py index 318f52ce..446919b1 100644 --- a/bert/model.py +++ b/bert/model.py @@ -85,7 +85,7 @@ class MultiHeadAttention(nn.Module): scale = math.sqrt(1 / queries.shape[-1]) scores = (queries * scale) @ keys if mask is not None: - mask = self.converrt_mask_to_additive_causal_mask(mask) + mask = self.convert_mask_to_additive_causal_mask(mask) mask = mx.expand_dims(mask, (1, 2)) mask = mx.broadcast_to(mask, scores.shape) scores = scores + mask.astype(scores.dtype) @@ -94,7 +94,7 @@ class MultiHeadAttention(nn.Module): return self.out_proj(values_hat) - def converrt_mask_to_additive_causal_mask( + def convert_mask_to_additive_causal_mask( self, mask: mx.array, dtype: mx.Dtype = mx.float32 ) -> mx.array: mask = mask == 0 @@ -186,7 +186,7 @@ class Bert(nn.Module): self, input_ids: mx.array, token_type_ids: mx.array, - attention_mask: mx.array | None = None, + attention_mask: Optional[mx.array] = None, ) -> tuple[mx.array, mx.array]: x = self.embeddings(input_ids, token_type_ids) y = self.encoder(x, attention_mask)