run precommit

author junwoo-yun, 2023-12-25 07:50:54 +08:00
parent 297e69017c
commit 0e0557b756


@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import Any, Optional, Callable
+from typing import Any, Callable, Optional
 
 import mlx.core as mx
 from mlx.nn.layers.activations import relu
@@ -150,11 +150,20 @@ class TransformerEncoderLayer(Module):
 
 class TransformerEncoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        norm_first: bool = False,
+        activation=relu,
     ):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
+            TransformerEncoderLayer(
+                dims, num_heads, mlp_dims, dropout, norm_first, activation
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
@@ -231,11 +240,20 @@ class TransformerDecoderLayer(Module):
 
 class TransformerDecoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        norm_first: bool = False,
+        activation=relu,
     ):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
+            TransformerDecoderLayer(
+                dims, num_heads, mlp_dims, dropout, norm_first, activation
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
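
For context, here is how the encoder constructor reformatted above ends up being used. This is an illustrative sketch only, not part of the commit: it assumes an installed mlx package with the module layout shown in the imports (mlx.core, mlx.nn.layers.transformer), assumes the encoder is called as encoder(x, mask), and the layer sizes, sequence length, and absent mask are made-up values.

# Usage sketch (assumptions noted above): build a TransformerEncoder with the
# keyword arguments from the reformatted signature and run a dummy batch.
import mlx.core as mx
from mlx.nn.layers.transformer import TransformerEncoder

encoder = TransformerEncoder(
    num_layers=2,
    dims=64,
    num_heads=4,
    mlp_dims=None,       # each layer falls back to its default hidden size
    dropout=0.0,
    norm_first=False,
)

x = mx.random.uniform(shape=(1, 16, 64))  # (batch, sequence, dims), arbitrary
y = encoder(x, None)                      # None = no attention mask
print(y.shape)                            # same shape as the input: (1, 16, 64)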
@@ -270,6 +288,7 @@ class Transformer(Module):
         norm_first (bool): if ``True``, encoder and decoder layers will perform LayerNorms before
             other attention and feedforward operations, otherwise after. Default is``False``.
     """
+
     def __init__(
         self,
         dims: int = 512,
@@ -281,21 +300,33 @@ class Transformer(Module):
         activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
-        norm_first: bool = False
+        norm_first: bool = False,
     ):
         super().__init__()
         if custom_encoder is not None:
             self.encoder = custom_encoder
         else:
             self.encoder = TransformerEncoder(
-                num_encoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
+                num_encoder_layers,
+                dims,
+                num_heads,
+                mlp_dims,
+                dropout,
+                activation,
+                norm_first,
             )
 
         if custom_decoder is not None:
             self.decoder = custom_decoder
         else:
             self.decoder = TransformerDecoder(
-                num_decoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
+                num_decoder_layers,
+                dims,
+                num_heads,
+                mlp_dims,
+                dropout,
+                activation,
+                norm_first,
             )
 
     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
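
Similarly, a sketch of wiring the encoder and decoder together by hand, mirroring what Transformer.__call__ does internally (encode src into memory, then decode tgt against it). Again illustrative only and not part of the commit: it assumes TransformerEncoder, TransformerDecoder, and MultiHeadAttention are importable from mlx.nn.layers.transformer, that MultiHeadAttention.create_additive_causal_mask is available, and that the decoder is called as decoder(tgt, memory, tgt_mask, memory_mask), matching the mask names in __call__ above; all shapes and sizes are arbitrary.

# Usage sketch (assumptions noted above): manual encoder/decoder pipeline.
import mlx.core as mx
from mlx.nn.layers.transformer import (
    MultiHeadAttention,
    TransformerDecoder,
    TransformerEncoder,
)

dims, num_heads = 64, 4
encoder = TransformerEncoder(num_layers=2, dims=dims, num_heads=num_heads)
decoder = TransformerDecoder(num_layers=2, dims=dims, num_heads=num_heads)

src = mx.random.uniform(shape=(1, 10, dims))  # (batch, src_len, dims)
tgt = mx.random.uniform(shape=(1, 7, dims))   # (batch, tgt_len, dims)

# Additive causal mask: each target position attends only to earlier positions.
tgt_mask = MultiHeadAttention.create_additive_causal_mask(7)

memory = encoder(src, None)                 # no source mask
out = decoder(tgt, memory, tgt_mask, None)  # no memory mask
print(out.shape)                            # (1, 7, 64), same shape as tgt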