Mirror of https://github.com/ml-explore/mlx-examples.git (last synced 2025-08-30 02:53:41 +08:00).
Use position bias in decoder
This commit is contained in:
parent
7dcf2b688d
commit
daea1dcddf
15
t5/t5.py
15
t5/t5.py
@ -232,13 +232,13 @@ class TransformerDecoderLayer(nn.Module):
|
||||
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
|
||||
|
||||
def __call__(self, x, memory, x_mask, memory_mask):
|
||||
def __call__(self, x, memory, x_mask, memory_mask, position_bias=None):
|
||||
y = self.ln1(x)
|
||||
y = self.self_attention(y, y, y, x_mask)
|
||||
y, position_bias = self.self_attention(y, y, y, x_mask, position_bias=position_bias)
|
||||
x = x + y
|
||||
|
||||
y = self.ln2(x)
|
||||
y = self.cross_attention(x, memory, memory, memory_mask)
|
||||
y, _ = self.cross_attention(x, memory, memory, memory_mask)
|
||||
x = x + y
|
||||
|
||||
y = self.ln3(x)
|
||||
@ -247,7 +247,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
y = self.linear2(y)
|
||||
x = x + y
|
||||
|
||||
return x
|
||||
return x, position_bias
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
@ -260,8 +260,11 @@ class TransformerDecoder(nn.Module):
|
||||
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
|
||||
def __call__(self, x, memory, x_mask, memory_mask):
|
||||
position_bias = None
|
||||
for layer in self.layers:
|
||||
x = layer(x, memory, x_mask, memory_mask)
|
||||
x, position_bias = layer(
|
||||
x, memory, x_mask, memory_mask, position_bias=position_bias
|
||||
)
|
||||
x = self.ln(x)
|
||||
|
||||
return x
|
||||
@ -298,7 +301,7 @@ class T5(nn.Module):
|
||||
mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
y, cache = self.decoder(
|
||||
y = self.decoder(
|
||||
x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
|
||||
) # , cache)
|
||||
return self.lm_head(y), cache
|
||||
|
Loading…
Reference in New Issue
Block a user