From a577abc31320f3c03f90dc20201421a375a73ae5 Mon Sep 17 00:00:00 2001
From: Joe Barrow
Date: Sat, 9 Dec 2023 21:21:24 -0500
Subject: [PATCH] Cleaner masking code

---
 bert/model.py | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/bert/model.py b/bert/model.py
index d4dccfac..4666a78d 100644
--- a/bert/model.py
+++ b/bert/model.py
@@ -123,20 +123,11 @@ class Bert(nn.Module):
 
         if attention_mask is not None:
             # convert 0's to -infs, 1's to 0's, and make it broadcastable
-            attention_mask = self.convert_mask_to_additive_causal_mask(attention_mask)
+            attention_mask = mx.log(attention_mask)
             attention_mask = mx.expand_dims(attention_mask, (1, 2))
 
         y = self.encoder(x, attention_mask)
         return y, mx.tanh(self.pooler(y[:, 0]))
-
-
-    def convert_mask_to_additive_causal_mask(
-        self, mask: mx.array, dtype: mx.Dtype = mx.float32
-    ) -> mx.array:
-        mask = mask == 0
-        mask = mask.astype(dtype) * -1e9
-        return mask
-
 
 def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
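
For context, a minimal standalone sketch (not part of the patch) of the trick the
change relies on: mx.log maps 1 to 0 and 0 to -inf, so a binary padding mask becomes
an additive attention mask in a single call, replacing the removed
convert_mask_to_additive_causal_mask helper, which built the same thing from a large
negative constant (-1e9) rather than -inf. The example mask values are illustrative only.

    import mlx.core as mx

    # 1 = real token, 0 = padding
    mask = mx.array([[1, 1, 1, 0, 0]], dtype=mx.float32)

    # New approach: log turns 1 -> 0 and 0 -> -inf in one call
    additive = mx.log(mask)                      # [[0, 0, 0, -inf, -inf]]
    additive = mx.expand_dims(additive, (1, 2))  # broadcastable over (batch, heads, q_len, k_len)

    # Old helper, for comparison: flip the mask and scale by a large negative number
    old_additive = (mask == 0).astype(mx.float32) * -1e9

After the additive mask is summed into the attention scores, softmax drives the masked
positions' weights to (effectively) zero in both versions; -inf does so exactly, while
-1e9 merely approximates it.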