From 72581e5c1a07f42fc18aff21b2416a3849ad08d2 Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Mon, 18 Dec 2023 15:50:29 -0500 Subject: [PATCH] Fix attention for 3b model --- t5/t5.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/t5/t5.py b/t5/t5.py index 8f344d11..cd884c48 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -93,11 +93,12 @@ class RelativePositionBias(nn.Module): class MultiHeadAttention(nn.Module): def __init__(self, config: T5Config): super().__init__() + inner_dim = config.d_kv * config.num_heads self.num_heads = config.num_heads - self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False) - self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False) - self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False) - self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False) + self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False) def __call__( self,