From 7dcf2b688d021cdc8244984a56de10d4337b4ed5 Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Sun, 17 Dec 2023 08:34:21 -0500 Subject: [PATCH] Fix decoder mask --- t5/t5.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/t5/t5.py b/t5/t5.py index e955e372..bf79aa4f 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -124,9 +124,12 @@ class MultiHeadAttention(nn.Module): queries = self.query_proj(queries) keys = self.key_proj(keys) values = self.value_proj(values) + # print(f"queries: {queries}, {queries.abs().sum()}") + # print(f"keys: {keys}, {keys.abs().sum()}") + # print(f"values: {values}, {values.abs().sum()}") num_heads = self.num_heads - B, L, D = queries.shape + B, L, _ = queries.shape _, S, _ = keys.shape queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) @@ -289,11 +292,12 @@ class T5(nn.Module): x = self.wte(inputs) y = self.encoder(x, mask=None) # , cache) - if x.shape[1] > 1 and mask is None: - mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + decoder_inputs = self.wte(decoder_inputs) + decoder_n_tokens = decoder_inputs.shape[1] + if decoder_n_tokens > 1 and mask is None: + mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens) mask = mask.astype(x.dtype) - decoder_inputs = self.wte(decoder_inputs) y, cache = self.decoder( x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None ) # , cache)