From 62924d8135214a36b3bed21a6a2078fba7310d6f Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Thu, 14 Dec 2023 15:51:03 -0500 Subject: [PATCH] Pass config to all modules, fix ln --- t5/t5.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/t5/t5.py b/t5/t5.py index 0774a1fa..b2ec7717 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -46,14 +46,14 @@ class LayerNorm(nn.Module): class TransformerEncoderLayer(nn.Module): - def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None): + def __init__(self, config): super().__init__() - mlp_dims = mlp_dims or dims * 4 - self.attention = nn.MultiHeadAttention(dims, num_heads) - self.ln1 = LayerNorm(dims) - self.ln2 = LayerNorm(dims) - self.linear1 = nn.Linear(dims, mlp_dims, bias=False) - self.linear2 = nn.Linear(mlp_dims, dims, bias=False) + mlp_dims = config.d_ff or config.d_model * 4 + self.attention = nn.MultiHeadAttention(config.d_model, config.num_heads) + self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False) + self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False) def __call__(self, x, mask): y = self.ln1(x) @@ -70,15 +70,13 @@ class TransformerEncoderLayer(nn.Module): class TransformerEncoder(nn.Module): - def __init__( - self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None - ): + def __init__(self, config: ModelArgs): super().__init__() self.layers = [ - TransformerEncoderLayer(dims, num_heads, mlp_dims) - for _ in range(num_layers) + TransformerEncoderLayer(config) + for _ in range(config.num_layers) ] - self.ln = LayerNorm(dims) + self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon) def __call__(self, x, mask): for layer in self.layers: @@ -91,12 +89,7 @@ class TransformerEncoder(nn.Module): class T5(nn.Module): def __init__(self, config: ModelArgs): self.wte = nn.Embedding(config.vocab_size, config.d_model) - self.encoder = TransformerEncoder( - num_layers=config.num_layers, - dims=config.d_model, - num_heads=config.num_heads, - mlp_dims=config.d_ff, - ) + self.encoder = TransformerEncoder(config) # self.decoder = TransformerDecoder(config) # self.lm_head = OutputHead(config)