From 376bb9cc4498e42390b5c372482554160735dbfe Mon Sep 17 00:00:00 2001 From: Race <43013378+raceee@users.noreply.github.com> Date: Wed, 13 Mar 2024 10:24:21 -0700 Subject: [PATCH] bert encoder inherits from nn.Module now (#571) --- bert/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bert/model.py b/bert/model.py index f1e05313..381dd5e2 100644 --- a/bert/model.py +++ b/bert/model.py @@ -87,6 +87,7 @@ class TransformerEncoder(nn.Module): class BertEmbeddings(nn.Module): def __init__(self, config: ModelArgs): + super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.dim) self.token_type_embeddings = nn.Embedding(2, config.dim) self.position_embeddings = nn.Embedding( @@ -107,6 +108,7 @@ class BertEmbeddings(nn.Module): class Bert(nn.Module): def __init__(self, config: ModelArgs): + super().__init__() self.embeddings = BertEmbeddings(config) self.encoder = TransformerEncoder( num_layers=config.num_hidden_layers,