mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
bert encoder inherits from nn.Module now (#571)
This commit is contained in:
parent
14fe868825
commit
376bb9cc44
@ -87,6 +87,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
|
|
||||||
class BertEmbeddings(nn.Module):
|
class BertEmbeddings(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
|
||||||
self.token_type_embeddings = nn.Embedding(2, config.dim)
|
self.token_type_embeddings = nn.Embedding(2, config.dim)
|
||||||
self.position_embeddings = nn.Embedding(
|
self.position_embeddings = nn.Embedding(
|
||||||
@ -107,6 +108,7 @@ class BertEmbeddings(nn.Module):
|
|||||||
|
|
||||||
class Bert(nn.Module):
|
class Bert(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
self.embeddings = BertEmbeddings(config)
|
self.embeddings = BertEmbeddings(config)
|
||||||
self.encoder = TransformerEncoder(
|
self.encoder = TransformerEncoder(
|
||||||
num_layers=config.num_hidden_layers,
|
num_layers=config.num_hidden_layers,
|
||||||
|
Loading…
Reference in New Issue
Block a user