bert encoder inherits from nn.Module now (#571)

This commit is contained in:
Race 2024-03-13 10:24:21 -07:00 committed by GitHub
parent 14fe868825
commit 376bb9cc44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -87,6 +87,7 @@ class TransformerEncoder(nn.Module):
class BertEmbeddings(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.token_type_embeddings = nn.Embedding(2, config.dim)
self.position_embeddings = nn.Embedding(
@ -107,6 +108,7 @@ class BertEmbeddings(nn.Module):
class Bert(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.embeddings = BertEmbeddings(config)
self.encoder = TransformerEncoder(
num_layers=config.num_hidden_layers,