add transformer with dropout, fix transformer ffn and layernorm order

Jyun1998 2023-12-15 03:26:17 +08:00
parent fb675de30d
commit 4bd3c02c00
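
In short, this commit switches TransformerEncoderLayer and TransformerDecoderLayer from a pre-norm residual order to a post-norm one, and adds *WithDropout variants that apply dropout to each sub-layer output. A minimal sketch of the two orderings, where attn and ffn stand in for the attention and feed-forward sub-layers (the sketch is illustrative, not code from this file):

# Sketch only: attn and ffn are stand-ins for the layer's sub-modules.

def pre_norm_block(x, attn, ffn, ln1, ln2, mask):
    # Old order: normalize first, apply the sub-layer, then add the residual.
    x = x + attn(ln1(x), mask)
    x = x + ffn(ln2(x))
    return x

def post_norm_block(x, attn, ffn, ln1, ln2, mask):
    # New order: apply the sub-layer, add the residual, then normalize.
    x = ln1(x + attn(x, mask))
    return ln2(x + ffn(x))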


@@ -7,7 +7,49 @@ import mlx.core as mx
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import LayerNorm
+from mlx.nn.layers.dropout import Dropout
+from mlx.nn.layers.positional_encoding import SinusoidalPositionalEncoding
+
+
+class MyPosEncoding(SinusoidalPositionalEncoding):
+    def __init__(
+        self,
+        dims: int,
+        min_freq: float = 0.0001,
+        max_freq: float = 1,
+        scale: Optional[float] = None,
+        cos_first: bool = False,
+        full_turns: bool = False,
+    ):
+        super().__init__(
+            dims,
+            min_freq=min_freq,
+            max_freq=max_freq,
+            scale=scale,
+            cos_first=cos_first,
+            full_turns=full_turns,
+        )
+        self.dims = dims
+
+    def __call__(self, x):
+        # Assuming x.shape [batch_size, sequence_length, embedding_dim]
+        seq_length = x.shape[1]
+        position = mx.arange(seq_length)[..., None] * self._sigmas
+
+        # Generate positional encodings
+        div_term = mx.exp(mx.arange(0, self.dims, 2) * -(math.log(10000.0) / self.dims))
+        sinusoid_inp = position * div_term
+        y = mx.zeros((seq_length, self.dims))
+        if self.cos_first:
+            y[:, 0::2] = mx.cos(sinusoid_inp)
+            y[:, 1::2] = mx.sin(sinusoid_inp)
+        else:
+            y[:, 0::2] = mx.sin(sinusoid_inp)
+            y[:, 1::2] = mx.cos(sinusoid_inp)
+        if self.scale != 1:
+            y = y * self.scale
+        return x + y


 class MultiHeadAttention(Module):
     """Implements the scaled dot product attention with multiple heads.
@@ -107,17 +149,41 @@ class TransformerEncoderLayer(Module):
         self.linear2 = Linear(mlp_dims, dims)
 
     def __call__(self, x, mask):
-        y = self.ln1(x)
-        y = self.attention(y, y, y, mask)
-        x = x + y
-
-        y = self.ln2(x)
-        y = self.linear1(y)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-        x = x + y
-
-        return x
+        # Post-norm order: sub-layer, residual add, then LayerNorm
+        y = self.attention(x, x, x, mask)
+        x = self.ln1(x + y)
+
+        y = self.linear1(x)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+
+        y = self.ln2(x + y)
+        return y
+
+
+class TransformerEncoderLayerWithDropout(Module):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout_rate: float = 0.1,
+    ):
+        super().__init__()
+        mlp_dims = mlp_dims or dims * 4
+        self.attention = MultiHeadAttention(dims, num_heads)
+        self.ln1 = LayerNorm(dims)
+        self.ln2 = LayerNorm(dims)
+        self.linear1 = Linear(dims, mlp_dims)
+        self.linear2 = Linear(mlp_dims, dims)
+        self.dropout1 = Dropout(dropout_rate)
+        self.dropout2 = Dropout(dropout_rate)
+
+    def __call__(self, x, mask):
+        # Attention sub-layer: attend, drop, residual add, normalize
+        y = self.attention(x, x, x, mask)
+        y = self.dropout1(y)
+        x = self.ln1(x + y)
+
+        # Feed-forward sub-layer: FFN, drop, residual add, normalize
+        y = self.linear1(x)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+        y = self.dropout2(y)
+        y = self.ln2(x + y)
+        return y


 class TransformerEncoder(Module):
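
A hedged usage sketch for the TransformerEncoderLayerWithDropout added in this hunk; the shapes are arbitrary examples, dims must be divisible by num_heads, and because dropout is stochastic the output varies between calls while dropout is active:

import mlx.core as mx

layer = TransformerEncoderLayerWithDropout(dims=64, num_heads=4, dropout_rate=0.1)
x = mx.random.normal((2, 16, 64))  # [batch, seq_len, dims]
y = layer(x, mask=None)            # no attention mask for a plain encoder
print(y.shape)                     # (2, 16, 64)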
@@ -139,6 +205,26 @@ class TransformerEncoder(Module):
         return x
+
+
+class TransformerEncoderWithDropout(Module):
+    def __init__(
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+    ):
+        super().__init__()
+        self.layers = [
+            TransformerEncoderLayerWithDropout(dims, num_heads, mlp_dims)
+            for i in range(num_layers)
+        ]
+        self.ln = LayerNorm(dims)
+
+    def __call__(self, x, mask):
+        for l in self.layers:
+            x = l(x, mask)
+        x = self.ln(x)
+
+        return x


 class TransformerDecoderLayer(Module):
     def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
         super().__init__()
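
A minimal sketch of the stacked TransformerEncoderWithDropout added above. Note that this wrapper does not forward a dropout_rate, so every layer falls back to the 0.1 default; the sizes below are assumptions:

import mlx.core as mx

encoder = TransformerEncoderWithDropout(num_layers=2, dims=64, num_heads=4)
src = mx.random.normal((2, 16, 64))  # [batch, src_len, dims]
memory = encoder(src, mask=None)     # same shape as src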
@@ -152,21 +238,51 @@ class TransformerDecoderLayer(Module):
         self.linear2 = Linear(mlp_dims, dims)
 
     def __call__(self, x, memory, x_mask, memory_mask):
-        y = self.ln1(x)
-        y = self.self_attention(y, y, y, x_mask)
-        x = x + y
-
-        y = self.ln2(x)
-        y = self.cross_attention(x, memory, memory, memory_mask)
-        x = x + y
-
-        y = self.ln3(x)
-        y = self.linear1(y)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
-        x = x + y
-
-        return x
+        # Post-norm order: sub-layer, residual add, then LayerNorm
+        y = self.self_attention(x, x, x, x_mask)
+        x = self.ln1(x + y)
+
+        y = self.cross_attention(x, memory, memory, memory_mask)
+        x = self.ln2(x + y)
+
+        y = self.linear1(x)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+
+        y = self.ln3(x + y)
+        return y
+
+
+class TransformerDecoderLayerWithDropout(Module):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout_rate: float = 0.1,
+    ):
+        super().__init__()
+        mlp_dims = mlp_dims or dims * 4
+        self.self_attention = MultiHeadAttention(dims, num_heads)
+        self.cross_attention = MultiHeadAttention(dims, num_heads)
+        self.ln1 = LayerNorm(dims)
+        self.ln2 = LayerNorm(dims)
+        self.ln3 = LayerNorm(dims)
+        self.linear1 = Linear(dims, mlp_dims)
+        self.linear2 = Linear(mlp_dims, dims)
+        self.dropout1 = Dropout(dropout_rate)
+        self.dropout2 = Dropout(dropout_rate)
+        self.dropout3 = Dropout(dropout_rate)
+
+    def __call__(self, x, memory, x_mask, memory_mask):
+        # Self-attention sub-layer
+        y = self.self_attention(x, x, x, x_mask)
+        y = self.dropout1(y)
+        x = self.ln1(x + y)
+
+        # Cross-attention over the encoder memory
+        y = self.cross_attention(x, memory, memory, memory_mask)
+        y = self.dropout2(y)
+        x = self.ln2(x + y)
+
+        # Feed-forward sub-layer
+        y = self.linear1(x)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+        y = self.dropout3(y)
+        y = self.ln3(x + y)
+        return y


 class TransformerDecoder(Module):
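
A hedged sketch of driving the TransformerDecoderLayerWithDropout added in this hunk against encoder memory. The causal mask uses the existing MultiHeadAttention.create_additive_causal_mask helper from this module; all shapes are illustrative:

import mlx.core as mx

dec_layer = TransformerDecoderLayerWithDropout(dims=64, num_heads=4, dropout_rate=0.1)
memory = mx.random.normal((2, 16, 64))                        # e.g. encoder output
tgt = mx.random.normal((2, 8, 64))                            # decoder-side input
tgt_mask = MultiHeadAttention.create_additive_causal_mask(8)  # additive causal mask
out = dec_layer(tgt, memory, tgt_mask, None)                  # no cross-attention mask
print(out.shape)                                              # (2, 8, 64)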
@@ -188,6 +304,25 @@ class TransformerDecoder(Module):
         return x
+
+
+class TransformerDecoderWithDropout(Module):
+    def __init__(
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+    ):
+        super().__init__()
+        self.layers = [
+            TransformerDecoderLayerWithDropout(dims, num_heads, mlp_dims)
+            for i in range(num_layers)
+        ]
+        self.ln = LayerNorm(dims)
+
+    def __call__(self, x, memory, x_mask, memory_mask):
+        for l in self.layers:
+            x = l(x, memory, x_mask, memory_mask)
+        x = self.ln(x)
+
+        return x


 class Transformer(Module):
     def __init__(
         self,