mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
bert encoder inherits from nn.Module now (#571)
This commit is contained in:
parent
14fe868825
commit
376bb9cc44
@ -87,6 +87,7 @@ class TransformerEncoder(nn.Module):
|
||||
|
||||
class BertEmbeddings(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
|
||||
self.token_type_embeddings = nn.Embedding(2, config.dim)
|
||||
self.position_embeddings = nn.Embedding(
|
||||
@ -107,6 +108,7 @@ class BertEmbeddings(nn.Module):
|
||||
|
||||
class Bert(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
self.encoder = TransformerEncoder(
|
||||
num_layers=config.num_hidden_layers,
|
||||
|
Loading…
Reference in New Issue
Block a user