Mirror of https://github.com/ml-explore/mlx.git, synced 2025-08-22 04:56:41 +08:00
run precommit
commit 0e0557b756
parent 297e69017c
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import Any, Optional, Callable
+from typing import Any, Callable, Optional
 
 import mlx.core as mx
 from mlx.nn.layers.activations import relu
@@ -133,7 +133,7 @@ class TransformerEncoderLayer(Module):
             y = self.dropout2(y)
             y = self.linear2(y)
             y = x + y
-
+
         else:
             y = self.attention(x, x, x, mask)
             y = self.dropout1(y)
@@ -144,17 +144,26 @@ class TransformerEncoderLayer(Module):
             y = self.dropout2(y)
             y = self.linear2(y)
             y = self.ln2(x + y)
-
+
         return y


 class TransformerEncoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        norm_first: bool = False,
+        activation=relu,
     ):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
+            TransformerEncoderLayer(
+                dims, num_heads, mlp_dims, dropout, norm_first, activation
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
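Not part of the commit, but for context, a minimal usage sketch of the constructor reformatted above (the same pattern applies to TransformerDecoder below). The module path mlx.nn.layers.transformer and the assumption that TransformerEncoder.__call__ accepts (x, mask) are not confirmed by this diff; the mask is additive, as in mlx.nn.MultiHeadAttention, so an all-zeros mask leaves attention unrestricted.

# Illustrative sketch, not from the commit. Assumes this file is
# mlx/nn/layers/transformer.py and that TransformerEncoder.__call__(x, mask)
# forwards the additive mask to each layer's self-attention.
import mlx.core as mx
from mlx.nn.layers.activations import relu                  # import shown in the diff
from mlx.nn.layers.transformer import TransformerEncoder    # module path assumed

dims, num_heads, seq_len = 64, 4, 16
encoder = TransformerEncoder(
    num_layers=2,
    dims=dims,
    num_heads=num_heads,
    mlp_dims=None,     # None -> feedforward width defaults to 4 * dims
    dropout=0.0,
    norm_first=True,
    activation=relu,
)

x = mx.random.normal((1, seq_len, dims))  # (batch, sequence, feature) embeddings
mask = mx.zeros((seq_len, seq_len))       # additive zero mask: attend everywhere
y = encoder(x, mask)
print(y.shape)                            # expected: (1, 16, 64)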
@@ -210,7 +219,7 @@ class TransformerDecoderLayer(Module):
             y = self.dropout3(y)
             y = self.linear2(y)
             x = x + y
-
+
         else:
             y = self.self_attention(x, x, x, x_mask)
             y = self.dropout1(y)
@@ -231,11 +240,20 @@ class TransformerDecoderLayer(Module):

 class TransformerDecoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        norm_first: bool = False,
+        activation=relu,
     ):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
+            TransformerDecoderLayer(
+                dims, num_heads, mlp_dims, dropout, norm_first, activation
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
@@ -261,15 +279,16 @@ class Transformer(Module):
         num_heads (int): The number of heads in the multi-head attention models.
         num_encoder_layers (int): The number of sub-encoder-layers in the Transformer encoder.
         num_decoder_layers (int): The number of sub-decoder-layers in the Transformer decoder.
-        mlp_dims (Optional[int]): The dimensionality of the feedforward network model in each Transformer layer.
+        mlp_dims (Optional[int]): The dimensionality of the feedforward network model in each Transformer layer.
+            Defaults to 4*dims if not provided.
         dropout (float): The dropout value for Transformer encoder/decoder.
         activation (Callable[[Any], Any]): the activation function of encoder/decoder intermediate layer
-        custom_encoder (Optional[Any]): A custom encoder to replace the standard Transformer encoder.
-        custom_decoder (Optional[Any]): A custom decoder to replace the standard Transformer decoder.
+        custom_encoder (Optional[Any]): A custom encoder to replace the standard Transformer encoder.
+        custom_decoder (Optional[Any]): A custom decoder to replace the standard Transformer decoder.
         norm_first (bool): if ``True``, encoder and decoder layers will perform LayerNorms before
             other attention and feedforward operations, otherwise after. Default is``False``.
     """

     def __init__(
         self,
         dims: int = 512,
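As a side note on the docstring above (illustrative, not from the commit): with the default dims=512 and mlp_dims left as None, each layer's feedforward block ends up 4 * 512 = 2048 wide. The fallback expression below is an assumed sketch of that documented behavior, not the file's actual code.

# Assumed sketch of the documented default: mlp_dims falls back to 4 * dims.
dims = 512
mlp_dims = None
mlp_dims = mlp_dims or 4 * dims  # -> 2048 for the default dims=512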
@@ -281,21 +300,33 @@ class Transformer(Module):
         activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
-        norm_first: bool = False
+        norm_first: bool = False,
     ):
         super().__init__()
         if custom_encoder is not None:
             self.encoder = custom_encoder
         else:
             self.encoder = TransformerEncoder(
-                num_encoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
+                num_encoder_layers,
+                dims,
+                num_heads,
+                mlp_dims,
+                dropout,
+                activation,
+                norm_first,
             )

         if custom_decoder is not None:
             self.decoder = custom_decoder
         else:
             self.decoder = TransformerDecoder(
-                num_decoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
+                num_decoder_layers,
+                dims,
+                num_heads,
+                mlp_dims,
+                dropout,
+                activation,
+                norm_first,
             )

     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
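The __call__ signature above takes explicit attention masks for the encoder, the decoder, and the decoder's cross-attention over the encoder output. Below is a sketch of how such additive masks can be built with mlx.core; the helper name and the -1e9 fill value are illustrative assumptions, not part of this file.

# Sketch only. mlx.nn.MultiHeadAttention adds the mask to the attention scores,
# so 0.0 means "may attend" and a large negative value means "may not attend".
import mlx.core as mx

def additive_causal_mask(n: int) -> mx.array:
    # Upper-triangular -1e9 entries block attention to future positions.
    rows = mx.reshape(mx.arange(n), (n, 1))
    cols = mx.reshape(mx.arange(n), (1, n))
    return (cols > rows).astype(mx.float32) * -1e9

src_len, tgt_len = 16, 12
src_mask = mx.zeros((src_len, src_len))     # encoder self-attention: unrestricted
tgt_mask = additive_causal_mask(tgt_len)    # decoder self-attention: causal
memory_mask = mx.zeros((tgt_len, src_len))  # decoder cross-attention: unrestricted

# With an encoder-decoder model exposing the signature above:
# out = model(src, tgt, src_mask, tgt_mask, memory_mask)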