diff --git a/t5/t5.py b/t5/t5.py
index 6a21c791..a166d4c4 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -266,6 +266,8 @@ class T5(nn.Module):
         self.encoder = TransformerEncoder(config)
         self.decoder = TransformerDecoder(config)
         self.lm_head = OutputHead(config)
+        self.tie_word_embeddings = config.tie_word_embeddings
+        self.model_dim = config.d_model

     def encode(self, inputs: mx.array):
         return self.encoder(self.wte(inputs))
@@ -287,6 +289,10 @@ class T5(nn.Module):
         y, cache = self.decoder(
             inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
         )
+        if self.tie_word_embeddings:
+            # Rescale output before projecting on vocab
+            # See https://github.com/huggingface/transformers/blob/71d47f0ad498b7649f11d3a9cca3cd3585e4341f/src/transformers/models/t5/modeling_t5.py#L1766C9-L1769C71
+            y *= self.model_dim ** -0.5
         return self.lm_head(y), cache

     def __call__(
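
A minimal standalone sketch of what the added rescaling does (names and shapes below are illustrative assumptions, not part of the repository): when the LM head reuses the input embedding table, T5 multiplies the decoder output by d_model ** -0.5 before the vocab projection so the logits keep roughly the same scale they would have with a separately initialized, untied head.

    import mlx.core as mx

    d_model, vocab_size = 512, 32128
    wte = mx.random.normal((vocab_size, d_model))  # shared embedding table (illustrative)
    y = mx.random.normal((1, 4, d_model))          # decoder output: (batch, seq, d_model)

    # Tied head: rescale, then project with the shared embedding matrix.
    logits_tied = (y * d_model ** -0.5) @ mx.transpose(wte)

    # An untied head would simply apply its own linear projection, with no rescaling.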