run precommit

This commit is contained in:
junwoo-yun 2023-12-25 07:52:19 +08:00
parent 0e0557b756
commit 88a94b9db8

View File

@@ -156,13 +156,13 @@ class TransformerEncoder(Module):
num_heads: int, num_heads: int,
mlp_dims: Optional[int] = None, mlp_dims: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
norm_first: bool = False,
activation=relu, activation=relu,
norm_first: bool = False,
): ):
super().__init__() super().__init__()
self.layers = [ self.layers = [
TransformerEncoderLayer( TransformerEncoderLayer(
dims, num_heads, mlp_dims, dropout, norm_first, activation dims, num_heads, mlp_dims, dropout, activation, norm_first
) )
for i in range(num_layers) for i in range(num_layers)
] ]
@@ -246,13 +246,13 @@ class TransformerDecoder(Module):
num_heads: int, num_heads: int,
mlp_dims: Optional[int] = None, mlp_dims: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
norm_first: bool = False,
activation=relu, activation=relu,
norm_first: bool = False,
): ):
super().__init__() super().__init__()
self.layers = [ self.layers = [
TransformerDecoderLayer( TransformerDecoderLayer(
dims, num_heads, mlp_dims, dropout, norm_first, activation dims, num_heads, mlp_dims, dropout, activation, norm_first
) )
for i in range(num_layers) for i in range(num_layers)
] ]