mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Rescale output before projecting on vocab
This commit is contained in:
parent
511f572b6c
commit
36fd88509e
6
t5/t5.py
6
t5/t5.py
@ -266,6 +266,8 @@ class T5(nn.Module):
|
||||
self.encoder = TransformerEncoder(config)
|
||||
self.decoder = TransformerDecoder(config)
|
||||
self.lm_head = OutputHead(config)
|
||||
self.tie_word_embeddings = config.tie_word_embeddings
|
||||
self.model_dim = config.d_model
|
||||
|
||||
def encode(self, inputs: mx.array):
|
||||
return self.encoder(self.wte(inputs))
|
||||
@ -287,6 +289,10 @@ class T5(nn.Module):
|
||||
y, cache = self.decoder(
|
||||
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
|
||||
)
|
||||
if self.tie_word_embeddings:
|
||||
# Rescale the decoder output by d_model ** -0.5 before projecting onto the vocabulary
|
||||
# See https://github.com/huggingface/transformers/blob/71d47f0ad498b7649f11d3a9cca3cd3585e4341f/src/transformers/models/t5/modeling_t5.py#L1766C9-L1769C71
|
||||
y *= self.model_dim ** -0.5
|
||||
return self.lm_head(y), cache
|
||||
|
||||
def __call__(
|
||||
|
Loading…
Reference in New Issue
Block a user