diff --git a/bert/model.py b/bert/model.py index f1e05313..381dd5e2 100644 --- a/bert/model.py +++ b/bert/model.py @@ -87,6 +87,7 @@ class TransformerEncoder(nn.Module): class BertEmbeddings(nn.Module): def __init__(self, config: ModelArgs): + super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.dim) self.token_type_embeddings = nn.Embedding(2, config.dim) self.position_embeddings = nn.Embedding( @@ -107,6 +108,7 @@ class BertEmbeddings(nn.Module): class Bert(nn.Module): def __init__(self, config: ModelArgs): + super().__init__() self.embeddings = BertEmbeddings(config) self.encoder = TransformerEncoder( num_layers=config.num_hidden_layers,