From 88a94b9db8584f9e1d68116c46d8b8d4da06adcb Mon Sep 17 00:00:00 2001
From: junwoo-yun
Date: Mon, 25 Dec 2023 07:52:19 +0800
Subject: [PATCH] run precommit

---
 python/mlx/nn/layers/transformer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index f1d9b42de..51a7f55f4 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -156,13 +156,13 @@ class TransformerEncoder(Module):
         num_heads: int,
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
-        norm_first: bool = False,
         activation=relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         self.layers = [
             TransformerEncoderLayer(
-                dims, num_heads, mlp_dims, dropout, norm_first, activation
+                dims, num_heads, mlp_dims, dropout, activation, norm_first
             )
             for i in range(num_layers)
         ]
@@ -246,13 +246,13 @@ class TransformerDecoder(Module):
         num_heads: int,
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
-        norm_first: bool = False,
         activation=relu,
+        norm_first: bool = False,
     ):
         super().__init__()
         self.layers = [
             TransformerDecoderLayer(
-                dims, num_heads, mlp_dims, dropout, norm_first, activation
+                dims, num_heads, mlp_dims, dropout, activation, norm_first
             )
             for i in range(num_layers)
         ]
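
Note (not part of the patch): the change above moves norm_first after activation in the
TransformerEncoder and TransformerDecoder constructors so their parameter order matches the
layer classes they instantiate positionally; before it, the bool norm_first was forwarded into
the layer's activation parameter and the activation callable into norm_first. The sketch below
illustrates the corrected behavior. It assumes the signatures visible in this patch (num_layers
and dims as the leading constructor arguments, a required mask in __call__) and is a usage
illustration, not code from this change.

    # Illustrative sketch; assumes the mlx API as shown in the patch above.
    import mlx.core as mx
    from mlx.nn import relu, MultiHeadAttention
    from mlx.nn.layers.transformer import TransformerEncoder

    # After the fix, positional forwarding inside TransformerEncoder matches
    # TransformerEncoderLayer(dims, num_heads, mlp_dims, dropout, activation,
    # norm_first), so both keyword arguments reach the right parameters.
    encoder = TransformerEncoder(
        num_layers=2,
        dims=64,
        num_heads=4,
        dropout=0.1,
        activation=relu,
        norm_first=True,  # pre-norm; previously this bool landed on activation
    )

    x = mx.random.normal((1, 10, 64))  # (batch, sequence length, dims)
    mask = MultiHeadAttention.create_additive_causal_mask(10)
    y = encoder(x, mask)
    print(y.shape)  # (1, 10, 64)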