diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py index f1d9b42de..51a7f55f4 100644 --- a/python/mlx/nn/layers/transformer.py +++ b/python/mlx/nn/layers/transformer.py @@ -156,13 +156,13 @@ class TransformerEncoder(Module): num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, - norm_first: bool = False, activation=relu, + norm_first: bool = False, ): super().__init__() self.layers = [ TransformerEncoderLayer( - dims, num_heads, mlp_dims, dropout, norm_first, activation + dims, num_heads, mlp_dims, dropout, activation, norm_first ) for i in range(num_layers) ] @@ -246,13 +246,13 @@ class TransformerDecoder(Module): num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, - norm_first: bool = False, activation=relu, + norm_first: bool = False, ): super().__init__() self.layers = [ TransformerDecoderLayer( - dims, num_heads, mlp_dims, dropout, norm_first, activation + dims, num_heads, mlp_dims, dropout, activation, norm_first ) for i in range(num_layers) ]