mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 13:07:51 +08:00
run precommit
This commit is contained in:
parent
0e0557b756
commit
88a94b9db8
@ -156,13 +156,13 @@ class TransformerEncoder(Module):
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
mlp_dims: Optional[int] = None,
|
mlp_dims: Optional[int] = None,
|
||||||
dropout: float = 0.0,
|
dropout: float = 0.0,
|
||||||
norm_first: bool = False,
|
|
||||||
activation=relu,
|
activation=relu,
|
||||||
|
norm_first: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = [
|
self.layers = [
|
||||||
TransformerEncoderLayer(
|
TransformerEncoderLayer(
|
||||||
dims, num_heads, mlp_dims, dropout, norm_first, activation
|
dims, num_heads, mlp_dims, dropout, activation, norm_first
|
||||||
)
|
)
|
||||||
for i in range(num_layers)
|
for i in range(num_layers)
|
||||||
]
|
]
|
||||||
@ -246,13 +246,13 @@ class TransformerDecoder(Module):
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
mlp_dims: Optional[int] = None,
|
mlp_dims: Optional[int] = None,
|
||||||
dropout: float = 0.0,
|
dropout: float = 0.0,
|
||||||
norm_first: bool = False,
|
|
||||||
activation=relu,
|
activation=relu,
|
||||||
|
norm_first: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = [
|
self.layers = [
|
||||||
TransformerDecoderLayer(
|
TransformerDecoderLayer(
|
||||||
dims, num_heads, mlp_dims, dropout, norm_first, activation
|
dims, num_heads, mlp_dims, dropout, activation, norm_first
|
||||||
)
|
)
|
||||||
for i in range(num_layers)
|
for i in range(num_layers)
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user