Fix scaling when embeddings are tied (#591)

commit e05e502c34
parent e4b19bb9e1
Author: Abdul Fatir (committed by GitHub)
Date:   2024-03-18 21:41:07 +01:00

@@ -315,9 +315,9 @@ class T5(nn.Module):
             inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
         )
         if not self.tie_word_embeddings:
-            y *= self.model_dim**-0.5
             y = self.lm_head(y)
         else:
+            y *= self.model_dim**-0.5
             y = y @ self.wte.weight.T
         return y, cache