Fix decoder mask

This commit is contained in:
Juarez Bochi 2023-12-17 08:34:21 -05:00
parent f26e81ccc9
commit 7dcf2b688d
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@@ -124,9 +124,12 @@ class MultiHeadAttention(nn.Module):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
# print(f"queries: {queries}, {queries.abs().sum()}")
# print(f"keys: {keys}, {keys.abs().sum()}")
# print(f"values: {values}, {values.abs().sum()}")
num_heads = self.num_heads
B, L, D = queries.shape
B, L, _ = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
@@ -289,11 +292,12 @@ class T5(nn.Module):
x = self.wte(inputs)
y = self.encoder(x, mask=None) # , cache)
if x.shape[1] > 1 and mask is None:
mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
decoder_inputs = self.wte(decoder_inputs)
decoder_n_tokens = decoder_inputs.shape[1]
if decoder_n_tokens > 1 and mask is None:
mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
mask = mask.astype(x.dtype)
decoder_inputs = self.wte(decoder_inputs)
y, cache = self.decoder(
x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
) # , cache)