mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 10:41:18 +08:00
Cleaner masking code
This commit is contained in:
parent
5d4838b02e
commit
a577abc313
@ -123,22 +123,13 @@ class Bert(nn.Module):
|
|||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# convert 0's to -infs, 1's to 0's, and make it broadcastable
|
# convert 0's to -infs, 1's to 0's, and make it broadcastable
|
||||||
attention_mask = self.convert_mask_to_additive_causal_mask(attention_mask)
|
attention_mask = mx.log(attention_mask)
|
||||||
attention_mask = mx.expand_dims(attention_mask, (1, 2))
|
attention_mask = mx.expand_dims(attention_mask, (1, 2))
|
||||||
|
|
||||||
y = self.encoder(x, attention_mask)
|
y = self.encoder(x, attention_mask)
|
||||||
return y, mx.tanh(self.pooler(y[:, 0]))
|
return y, mx.tanh(self.pooler(y[:, 0]))
|
||||||
|
|
||||||
|
|
||||||
def convert_mask_to_additive_causal_mask(
|
|
||||||
self, mask: mx.array, dtype: mx.Dtype = mx.float32
|
|
||||||
) -> mx.array:
|
|
||||||
mask = mask == 0
|
|
||||||
mask = mask.astype(dtype) * -1e9
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
|
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
|
||||||
# load the weights npz
|
# load the weights npz
|
||||||
weights = mx.load(weights_path)
|
weights = mx.load(weights_path)
|
||||||
|
Loading…
Reference in New Issue
Block a user