From daea1dcddfedaeff3dcbb30950af3cac30058463 Mon Sep 17 00:00:00 2001
From: Juarez Bochi
Date: Sun, 17 Dec 2023 08:40:10 -0500
Subject: [PATCH] Use position bias in decoder

---
 t5/t5.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/t5/t5.py b/t5/t5.py
index bf79aa4f..c165a964 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -232,13 +232,13 @@ class TransformerDecoderLayer(nn.Module):
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
         self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
 
-    def __call__(self, x, memory, x_mask, memory_mask):
+    def __call__(self, x, memory, x_mask, memory_mask, position_bias=None):
         y = self.ln1(x)
-        y = self.self_attention(y, y, y, x_mask)
+        y, position_bias = self.self_attention(y, y, y, x_mask, position_bias=position_bias)
         x = x + y
 
         y = self.ln2(x)
-        y = self.cross_attention(x, memory, memory, memory_mask)
+        y, _ = self.cross_attention(x, memory, memory, memory_mask)
         x = x + y
 
         y = self.ln3(x)
@@ -247,7 +247,7 @@ class TransformerDecoderLayer(nn.Module):
         y = self.linear2(y)
         x = x + y
 
-        return x
+        return x, position_bias
 
 
 class TransformerDecoder(nn.Module):
@@ -260,8 +260,11 @@ class TransformerDecoder(nn.Module):
         self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
 
     def __call__(self, x, memory, x_mask, memory_mask):
+        position_bias = None
         for layer in self.layers:
-            x = layer(x, memory, x_mask, memory_mask)
+            x, position_bias = layer(
+                x, memory, x_mask, memory_mask, position_bias=position_bias
+            )
         x = self.ln(x)
         return x
 
@@ -298,7 +301,7 @@ class T5(nn.Module):
         mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
         mask = mask.astype(x.dtype)
 
-        y, cache = self.decoder(
+        y = self.decoder(
             x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
         )  # , cache)
         return self.lm_head(y), cache
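
For context, the change has the decoder's self-attention compute T5's relative position bias once, in the first layer, and thread that same tensor through every later layer instead of recomputing it. Below is a minimal, framework-free sketch of that control flow (plain Python/NumPy with made-up toy helpers, not the actual mlx-examples t5.py API):

# Toy sketch (assumed/hypothetical names, not the real mlx-examples code):
# the first self-attention call builds the position bias; every later call
# receives it via the `position_bias` argument and skips the recomputation.
import numpy as np


def relative_position_bias(n):
    # Stand-in for T5's learned bucketed bias: one value per (query, key) offset.
    offsets = np.arange(n)[None, :] - np.arange(n)[:, None]
    return -np.abs(offsets).astype(np.float32)


def self_attention(x, position_bias=None):
    n, _ = x.shape
    if position_bias is None:            # only the first decoder layer pays this cost
        position_bias = relative_position_bias(n)
    scores = x @ x.T + position_bias     # bias is added to the attention logits
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x, position_bias    # return the bias so the caller can reuse it


def decoder(x, num_layers=4):
    position_bias = None
    for _ in range(num_layers):
        y, position_bias = self_attention(x, position_bias=position_bias)
        x = x + y                        # residual connection, as in the patch
    return x


print(decoder(np.random.randn(5, 8).astype(np.float32)).shape)  # -> (5, 8)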