run precommit

junwoo-yun 2023-12-25 07:50:54 +08:00
parent 297e69017c
commit 0e0557b756


@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import Any, Optional, Callable
+from typing import Any, Callable, Optional
 
 import mlx.core as mx
 from mlx.nn.layers.activations import relu
@@ -150,11 +150,20 @@ class TransformerEncoderLayer(Module):
 
 class TransformerEncoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        norm_first: bool = False,
+        activation=relu,
     ):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
+            TransformerEncoderLayer(
+                dims, num_heads, mlp_dims, dropout, norm_first, activation
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
@@ -231,11 +240,20 @@ class TransformerDecoderLayer(Module):
 
 class TransformerDecoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        norm_first: bool = False,
+        activation=relu,
     ):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
+            TransformerDecoderLayer(
+                dims, num_heads, mlp_dims, dropout, norm_first, activation
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
@@ -270,6 +288,7 @@ class Transformer(Module):
         norm_first (bool): if ``True``, encoder and decoder layers will perform LayerNorms before
             other attention and feedforward operations, otherwise after. Default is``False``.
     """
+
     def __init__(
         self,
         dims: int = 512,
@@ -281,21 +300,33 @@ class Transformer(Module):
         activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
-        norm_first: bool = False
+        norm_first: bool = False,
     ):
         super().__init__()
         if custom_encoder is not None:
             self.encoder = custom_encoder
         else:
             self.encoder = TransformerEncoder(
-                num_encoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
+                num_encoder_layers,
+                dims,
+                num_heads,
+                mlp_dims,
+                dropout,
+                activation,
+                norm_first,
             )
 
         if custom_decoder is not None:
             self.decoder = custom_decoder
         else:
             self.decoder = TransformerDecoder(
-                num_decoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
+                num_decoder_layers,
+                dims,
+                num_heads,
+                mlp_dims,
+                dropout,
+                activation,
+                norm_first,
             )
 
     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
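
The commit above only reflows the code; the class signatures themselves are unchanged. As a quick sanity check, a minimal usage sketch of the reformatted TransformerEncoder follows, assuming the class is re-exported under mlx.nn (otherwise it can be imported from mlx.nn.layers.transformer); the sizes, sequence length, and mask below are illustrative, not taken from the commit.

# Minimal sketch: build the encoder with keyword arguments matching the
# reformatted signature shown in this diff. nn.TransformerEncoder is assumed
# to be available at the package level; dims, num_heads, and the sequence
# length are made-up values for illustration only.
import mlx.core as mx
import mlx.nn as nn

encoder = nn.TransformerEncoder(
    num_layers=2,
    dims=64,
    num_heads=4,
    mlp_dims=None,       # the layer falls back to a default hidden size when None
    dropout=0.0,
    norm_first=False,
)

x = mx.random.normal((1, 8, 64))  # (batch, sequence_length, dims)
mask = nn.MultiHeadAttention.create_additive_causal_mask(8)
y = encoder(x, mask)              # output keeps the input shape
print(y.shape)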