Use position bias in decoder

Juarez Bochi 2023-12-17 08:40:10 -05:00
parent 7dcf2b688d
commit daea1dcddf

@@ -232,13 +232,13 @@ class TransformerDecoderLayer(nn.Module):
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
         self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)

-    def __call__(self, x, memory, x_mask, memory_mask):
+    def __call__(self, x, memory, x_mask, memory_mask, position_bias=None):
         y = self.ln1(x)
-        y = self.self_attention(y, y, y, x_mask)
+        y, position_bias = self.self_attention(y, y, y, x_mask, position_bias=position_bias)
         x = x + y
         y = self.ln2(x)
-        y = self.cross_attention(x, memory, memory, memory_mask)
+        y, _ = self.cross_attention(x, memory, memory, memory_mask)
         x = x + y
         y = self.ln3(x)
@@ -247,7 +247,7 @@ class TransformerDecoderLayer(nn.Module):
         y = self.linear2(y)
         x = x + y
-        return x
+        return x, position_bias


 class TransformerDecoder(nn.Module):
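
These hunks assume that the layer's self_attention and cross_attention modules take an optional position_bias argument and return a (values, position_bias) tuple, computing the bias only when none is passed in. That attention code is not part of this diff, so the following is only a minimal sketch of the assumed interface, written in the same MLX style; SimpleRelativeAttention, num_buckets, and the clipped-offset bucketing (instead of T5's log-spaced buckets) are illustrative choices, not the repository's actual code.

import mlx.core as mx
import mlx.nn as nn


class SimpleRelativeAttention(nn.Module):
    """Hypothetical stand-in for the attention used by TransformerDecoderLayer."""

    def __init__(self, dims: int, num_heads: int, num_buckets: int = 32):
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.query_proj = nn.Linear(dims, dims, bias=False)
        self.key_proj = nn.Linear(dims, dims, bias=False)
        self.value_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)
        # One learned scalar per (bucket, head), as in T5's relative attention bias.
        self.relative_bias = nn.Embedding(num_buckets, num_heads)

    def _position_bias(self, num_queries, num_keys):
        # Simplified bucketing: clip the (key - query) offset into a fixed range.
        # T5 proper uses log-spaced buckets; this only keeps the sketch short.
        context = mx.arange(num_queries)[:, None]
        memory = mx.arange(num_keys)[None, :]
        buckets = mx.clip(memory - context + self.num_buckets // 2, 0, self.num_buckets - 1)
        bias = self.relative_bias(buckets)  # (L_q, L_k, num_heads)
        return bias.transpose(2, 0, 1)      # (num_heads, L_q, L_k)

    def __call__(self, queries, keys, values, mask=None, position_bias=None):
        B, L, D = queries.shape
        S = keys.shape[1]
        h = self.num_heads
        q = self.query_proj(queries).reshape(B, L, h, -1).transpose(0, 2, 1, 3)
        k = self.key_proj(keys).reshape(B, S, h, -1).transpose(0, 2, 1, 3)
        v = self.value_proj(values).reshape(B, S, h, -1).transpose(0, 2, 1, 3)

        # T5-style attention: no 1/sqrt(d) scaling of the scores.
        scores = q @ k.transpose(0, 1, 3, 2)
        if position_bias is None:
            # Only the first decoder layer pays for this; later layers reuse it.
            position_bias = self._position_bias(L, S)
        scores = scores + position_bias
        if mask is not None:
            scores = scores + mask
        weights = mx.softmax(scores, axis=-1)
        out = (weights @ v).transpose(0, 2, 1, 3).reshape(B, L, D)
        return self.out_proj(out), position_bias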
@@ -260,8 +260,11 @@ class TransformerDecoder(nn.Module):
         self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

     def __call__(self, x, memory, x_mask, memory_mask):
+        position_bias = None
         for layer in self.layers:
-            x = layer(x, memory, x_mask, memory_mask)
+            x, position_bias = layer(
+                x, memory, x_mask, memory_mask, position_bias=position_bias
+            )
         x = self.ln(x)
         return x
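
With an attention module shaped like that sketch, the loop in this hunk builds the bias tensor exactly once, on the first layer, and every later layer adds the same shared bias to its attention scores; this mirrors how T5 computes its relative attention bias in the first block and reuses it in the remaining blocks. A hypothetical usage of the pattern, reusing the SimpleRelativeAttention stand-in from the sketch above:

# Hypothetical demo of the threading pattern; SimpleRelativeAttention is the
# stand-in sketched above, not the repository's real attention module.
layers = [SimpleRelativeAttention(dims=64, num_heads=4) for _ in range(3)]

x = mx.zeros((1, 5, 64))
position_bias = None
for layer in layers:
    # First iteration: position_bias is None, so the layer computes it.
    # Later iterations: the same (num_heads, 5, 5) bias is passed back in.
    x, position_bias = layer(x, x, x, position_bias=position_bias)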
@@ -298,7 +301,7 @@ class T5(nn.Module):
         mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
         mask = mask.astype(x.dtype)
-        y, cache = self.decoder(
+        y = self.decoder(
             x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
         )  # , cache)
         return self.lm_head(y), cache