diff --git a/t5/t5.py b/t5/t5.py index 8f344d11..cd884c48 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -93,11 +93,12 @@ class RelativePositionBias(nn.Module): class MultiHeadAttention(nn.Module): def __init__(self, config: T5Config): super().__init__() + inner_dim = config.d_kv * config.num_heads self.num_heads = config.num_heads - self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False) - self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False) - self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False) - self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False) + self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False) def __call__( self,