diff --git a/t5/t5.py b/t5/t5.py index 3812393c..556c2503 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -315,9 +315,9 @@ class T5(nn.Module): inputs, memory=memory, mask=mask, memory_mask=None, cache=cache ) if not self.tie_word_embeddings: - y *= self.model_dim**-0.5 y = self.lm_head(y) else: + y *= self.model_dim**-0.5 y = y @ self.wte.weight.T return y, cache