Rescale output before projecting on vocab

This commit is contained in:
Juarez Bochi 2023-12-18 13:43:03 -05:00
parent 511f572b6c
commit 36fd88509e
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@ -266,6 +266,8 @@ class T5(nn.Module):
self.encoder = TransformerEncoder(config) self.encoder = TransformerEncoder(config)
self.decoder = TransformerDecoder(config) self.decoder = TransformerDecoder(config)
self.lm_head = OutputHead(config) self.lm_head = OutputHead(config)
self.tie_word_embeddings = config.tie_word_embeddings
self.model_dim = config.d_model
def encode(self, inputs: mx.array): def encode(self, inputs: mx.array):
return self.encoder(self.wte(inputs)) return self.encoder(self.wte(inputs))
@ -287,6 +289,10 @@ class T5(nn.Module):
y, cache = self.decoder( y, cache = self.decoder(
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
) )
if self.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/huggingface/transformers/blob/71d47f0ad498b7649f11d3a9cca3cd3585e4341f/src/transformers/models/t5/modeling_t5.py#L1766C9-L1769C71
y *= self.model_dim ** -0.5
return self.lm_head(y), cache return self.lm_head(y), cache
def __call__( def __call__(