Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 01:41:19 +08:00)
Fix scaling when embeddings are tied (#591)
This commit is contained in:
parent: e4b19bb9e1
commit: e05e502c34
2
t5/t5.py
2
t5/t5.py
@@ -315,9 +315,9 @@ class T5(nn.Module):
             inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
         )
         if not self.tie_word_embeddings:
-            y *= self.model_dim**-0.5
             y = self.lm_head(y)
         else:
+            y *= self.model_dim**-0.5
             y = y @ self.wte.weight.T
         return y, cache
 
Loading…
Reference in New Issue
Block a user