commit 88a94b9db8
parent 0e0557b756

    run precommit
@@ -156,13 +156,13 @@ class TransformerEncoder(Module):
         num_heads: int,
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
-        norm_first: bool = False,
         activation=relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         self.layers = [
             TransformerEncoderLayer(
-                dims, num_heads, mlp_dims, dropout, norm_first, activation
+                dims, num_heads, mlp_dims, dropout, activation, norm_first
             )
             for i in range(num_layers)
         ]
@@ -246,13 +246,13 @@ class TransformerDecoder(Module):
         num_heads: int,
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
-        norm_first: bool = False,
         activation=relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         self.layers = [
             TransformerDecoderLayer(
-                dims, num_heads, mlp_dims, dropout, norm_first, activation
+                dims, num_heads, mlp_dims, dropout, activation, norm_first
             )
             for i in range(num_layers)
         ]
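
Both hunks make the same change: activation now precedes norm_first, in the constructor signatures and in the positional calls into TransformerEncoderLayer / TransformerDecoderLayer. The sketch below is not part of the commit; the import path and the argument values are assumptions read off the hunks above. It shows why keyword arguments insulate a call site from this kind of parameter reorder:

# Hedged sketch, not from the commit: the constructor's trailing parameters
# changed from (dropout, norm_first, activation) to
# (dropout, activation, norm_first). The import path is an assumption.
from mlx.nn.layers.transformer import TransformerEncoder

# A caller that passed norm_first and activation positionally in the old
# order would now bind each value to the other's parameter; keywords keep
# the call correct on either side of the reorder.
encoder = TransformerEncoder(
    num_layers=6,
    dims=512,
    num_heads=8,
    mlp_dims=2048,
    dropout=0.1,
    norm_first=False,
)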