mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
Cleaner masking code
This commit is contained in:
parent
5d4838b02e
commit
a577abc313
@ -123,20 +123,11 @@ class Bert(nn.Module):
|
||||
|
||||
if attention_mask is not None:
|
||||
# convert 0's to -infs, 1's to 0's, and make it broadcastable
|
||||
attention_mask = self.convert_mask_to_additive_causal_mask(attention_mask)
|
||||
attention_mask = mx.log(attention_mask)
|
||||
attention_mask = mx.expand_dims(attention_mask, (1, 2))
|
||||
|
||||
y = self.encoder(x, attention_mask)
|
||||
return y, mx.tanh(self.pooler(y[:, 0]))
|
||||
|
||||
|
||||
def convert_mask_to_additive_causal_mask(
|
||||
self, mask: mx.array, dtype: mx.Dtype = mx.float32
|
||||
) -> mx.array:
|
||||
mask = mask == 0
|
||||
mask = mask.astype(dtype) * -1e9
|
||||
return mask
|
||||
|
||||
|
||||
|
||||
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
|
||||
|
Loading…
Reference in New Issue
Block a user