Mirror of https://github.com/ml-explore/mlx.git
add transformer with dropout, fix transformer ffm, layernorm order
This commit is contained in:
parent fb675de30d
commit 4bd3c02c00
@@ -7,7 +7,49 @@ import mlx.core as mx
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import LayerNorm
+from mlx.nn.layers.dropout import Dropout
+from mlx.nn.layers.positional_encoding import SinusoidalPositionalEncoding
+
+
+class MyPosEncoding(SinusoidalPositionalEncoding):
+    def __init__(
+        self,
+        dims: int,
+        min_freq: float = 0.0001,
+        max_freq: float = 1,
+        scale: Optional[float] = None,
+        cos_first: bool = False,
+        full_turns: bool = False,
+    ):
+        super().__init__(
+            dims,
+            min_freq=min_freq,
+            max_freq=max_freq,
+            scale=scale,
+            cos_first=cos_first,
+            full_turns=full_turns
+        )
+        self.dims = dims
+
+    def __call__(self, x):
+        seq_length = x.shape[1]  # Assuming x.shape [batch_size, sequence_length, embedding_dim]
+        position = mx.arange(seq_length)[..., None] * self._sigmas
+
+        # Generate positional encodings
+        div_term = mx.exp(mx.arange(0, self.dims, 2) * -(math.log(10000.0) / self.dims))
+        sinusoid_inp = position * div_term
+
+        y = mx.zeros((seq_length, self.dims))
+        if self.cos_first:
+            y[:, 0::2] = mx.cos(sinusoid_inp)
+            y[:, 1::2] = mx.sin(sinusoid_inp)
+        else:
+            y[:, 0::2] = mx.sin(sinusoid_inp)
+            y[:, 1::2] = mx.cos(sinusoid_inp)
+
+        if self.scale != 1:
+            y = y * self.scale
+
+        return x + y
 
 
 class MultiHeadAttention(Module):
     """Implements the scaled dot product attention with multiple heads.
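For reference, the table that MyPosEncoding builds is the standard sinusoidal encoding, PE[pos, 2i] = sin(pos / 10000^(2i/dims)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/dims)), with the sin/cos order swapped when cos_first is set. Below is a minimal standalone sketch that computes the same table with mlx.core by interleaving sin and cos columns instead of assigning into slices; the helper name sinusoidal_table is illustrative and not part of this commit.

    import math

    import mlx.core as mx


    def sinusoidal_table(seq_length: int, dims: int) -> mx.array:
        # PE[pos, 2i]   = sin(pos / 10000 ** (2 * i / dims))
        # PE[pos, 2i+1] = cos(pos / 10000 ** (2 * i / dims))
        position = mx.arange(seq_length)[..., None]                              # (L, 1)
        div_term = mx.exp(mx.arange(0, dims, 2) * -(math.log(10000.0) / dims))   # (dims // 2,)
        angles = position * div_term                                             # (L, dims // 2)
        # Interleave along the last axis: [sin_0, cos_0, sin_1, cos_1, ...]
        table = mx.stack([mx.sin(angles), mx.cos(angles)], axis=-1)              # (L, dims // 2, 2)
        return table.reshape(seq_length, dims)

Adding such a (seq_length, dims) table to a (batch, seq_length, dims) activation broadcasts over the batch axis, which is what the __call__ above relies on.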
@@ -107,17 +149,41 @@ class TransformerEncoderLayer(Module):
         self.linear2 = Linear(mlp_dims, dims)
 
     def __call__(self, x, mask):
-        y = self.ln1(x)
-        y = self.attention(y, y, y, mask)
-        x = x + y
+        y = self.attention(x, x, x, mask)
+        y = self.ln1(x + y)
 
-        y = self.ln2(x)
         y = self.linear1(y)
         y = mx.maximum(y, 0)
         y = self.linear2(y)
-        x = x + y
 
-        return x
+        y = self.ln2(x + y)
+        return y
 
 
+class TransformerEncoderLayerWithDropout(Module):
+    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout_rate: float = 0.1):
+        super().__init__()
+        mlp_dims = mlp_dims or dims * 4
+        self.attention = MultiHeadAttention(dims, num_heads)
+        self.ln1 = LayerNorm(dims)
+        self.ln2 = LayerNorm(dims)
+        self.linear1 = Linear(dims, mlp_dims)
+        self.linear2 = Linear(mlp_dims, dims)
+        self.dropout1 = Dropout(dropout_rate)
+        self.dropout2 = Dropout(dropout_rate)
+
+    def __call__(self, x, mask):
+        y = self.attention(x, x, x, mask)
+        y = self.dropout1(y)
+        y = self.ln1(x + y)
+
+        y = self.linear1(y)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+        y = self.dropout2(y)
+
+        y = self.ln2(x + y)
+        return y
+
+
 class TransformerEncoder(Module):
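The rewritten TransformerEncoderLayer switches from a pre-norm layout (normalize, run the sublayer, then add the residual) to the post-norm layout of the original Transformer, where LayerNorm is applied after each residual add; the new TransformerEncoderLayerWithDropout additionally drops the sublayer outputs before the residual adds. A usage sketch under some assumptions: that the new class is importable from mlx.nn.layers.transformer, that mlx.nn.Module exposes train()/eval() to toggle Dropout, and that MultiHeadAttention.create_additive_causal_mask is available for building the mask.

    import mlx.core as mx
    import mlx.nn as nn

    from mlx.nn.layers.transformer import TransformerEncoderLayerWithDropout  # assumed import path

    x = mx.random.normal((2, 16, 64))  # (batch, sequence, dims)
    mask = nn.MultiHeadAttention.create_additive_causal_mask(16)

    layer = TransformerEncoderLayerWithDropout(dims=64, num_heads=4, dropout_rate=0.1)

    layer.train()             # dropout is stochastic while training
    y_train = layer(x, mask)

    layer.eval()              # dropout acts as the identity at inference time
    y_eval = layer(x, mask)   # same shape as x: (2, 16, 64)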
@@ -139,6 +205,26 @@ class TransformerEncoder(Module):
         return x
 
 
+class TransformerEncoderWithDropout(Module):
+    def __init__(
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+    ):
+        super().__init__()
+        self.layers = [
+            TransformerEncoderLayerWithDropout(dims, num_heads, mlp_dims)
+            for i in range(num_layers)
+        ]
+        self.ln = LayerNorm(dims)
+
+    def __call__(self, x, mask):
+        for l in self.layers:
+            x = l(x, mask)
+        x = self.ln(x)
+
+        return x
+
+
 class TransformerDecoderLayer(Module):
     def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
         super().__init__()
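TransformerEncoderWithDropout stacks num_layers of the dropout layer and finishes with a LayerNorm. As written, its constructor does not forward a dropout_rate, so every layer uses the default of 0.1. A minimal construction sketch, with the same import-path assumption as above:

    from mlx.nn.layers.transformer import TransformerEncoderWithDropout  # assumed import path

    # Six post-norm encoder layers, each with the default dropout_rate of 0.1.
    encoder = TransformerEncoderWithDropout(num_layers=6, dims=64, num_heads=8)
    encoder.eval()  # disable dropout for deterministic inference; encoder.train() re-enables it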
@@ -152,21 +238,51 @@ class TransformerDecoderLayer(Module):
         self.linear2 = Linear(mlp_dims, dims)
 
     def __call__(self, x, memory, x_mask, memory_mask):
-        y = self.ln1(x)
-        y = self.self_attention(y, y, y, x_mask)
-        x = x + y
+        y = self.self_attention(x, x, x, x_mask)
+        x = self.ln1(x + y)
 
-        y = self.ln2(x)
-        y = self.cross_attention(x, memory, memory, memory_mask)
-        x = x + y
+        y = self.cross_attention(y, memory, memory, memory_mask)
+        x = self.ln1(x + y)
 
-        y = self.ln3(x)
-        y = self.linear1(y)
+        y = self.linear1(x)
         y = mx.maximum(y, 0)
         y = self.linear2(y)
-        x = x + y
+        y = self.ln3(x + y)
 
-        return x
+        return y
 
 
+class TransformerDecoderLayerWithDropout(Module):
+    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout_rate: float = 0.1):
+        super().__init__()
+        mlp_dims = mlp_dims or dims * 4
+        self.self_attention = MultiHeadAttention(dims, num_heads)
+        self.cross_attention = MultiHeadAttention(dims, num_heads)
+        self.ln1 = LayerNorm(dims)
+        self.ln2 = LayerNorm(dims)
+        self.ln3 = LayerNorm(dims)
+        self.linear1 = Linear(dims, mlp_dims)
+        self.linear2 = Linear(mlp_dims, dims)
+        self.dropout1 = Dropout(dropout_rate)
+        self.dropout2 = Dropout(dropout_rate)
+        self.dropout3 = Dropout(dropout_rate)
+
+    def __call__(self, x, memory, x_mask, memory_mask):
+        y = self.self_attention(x, x, x, x_mask)
+        y = self.dropout1(y)
+        x = self.ln1(x + y)
+
+        y = self.cross_attention(y, memory, memory, memory_mask)
+        y = self.dropout2(y)
+        x = self.ln1(x + y)
+
+        y = self.linear1(x)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+        y = self.dropout3(y)
+        y = self.ln3(x + y)
+
+        return y
+
+
 class TransformerDecoder(Module):
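The decoder layers take the target activations x and the encoder output memory, plus an x_mask (typically causal) for self-attention and a memory_mask for cross-attention; passing None for a mask skips the corresponding masking in MultiHeadAttention. A sketch of a single call to the dropout variant, under the same import-path and API assumptions as above:

    import mlx.core as mx
    import mlx.nn as nn

    from mlx.nn.layers.transformer import TransformerDecoderLayerWithDropout  # assumed import path

    tgt = mx.random.normal((2, 10, 64))     # decoder input: (batch, target length, dims)
    memory = mx.random.normal((2, 16, 64))  # encoder output: (batch, source length, dims)

    x_mask = nn.MultiHeadAttention.create_additive_causal_mask(10)  # causal mask over target positions
    memory_mask = None                                              # or an additive padding mask over the source

    layer = TransformerDecoderLayerWithDropout(dims=64, num_heads=4, dropout_rate=0.1)
    layer.eval()
    out = layer(tgt, memory, x_mask, memory_mask)  # (2, 10, 64)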
@@ -188,6 +304,25 @@ class TransformerDecoder(Module):
         return x
 
 
+class TransformerDecoderWithDropout(Module):
+    def __init__(
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+    ):
+        super().__init__()
+        self.layers = [
+            TransformerDecoderLayerWithDropout(dims, num_heads, mlp_dims)
+            for i in range(num_layers)
+        ]
+        self.ln = LayerNorm(dims)
+
+    def __call__(self, x, memory, x_mask, memory_mask):
+        for l in self.layers:
+            x = l(x, memory, x_mask, memory_mask)
+        x = self.ln(x)
+
+        return x
+
+
 class Transformer(Module):
     def __init__(
         self,
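The existing Transformer wrapper (context above) pairs a TransformerEncoder with a TransformerDecoder; the dropout variants can be wired together the same way. An end-to-end sketch with the same assumptions as the earlier examples:

    import mlx.core as mx
    import mlx.nn as nn

    from mlx.nn.layers.transformer import (  # assumed import path
        TransformerDecoderWithDropout,
        TransformerEncoderWithDropout,
    )

    dims, num_heads = 64, 8
    encoder = TransformerEncoderWithDropout(num_layers=4, dims=dims, num_heads=num_heads)
    decoder = TransformerDecoderWithDropout(num_layers=4, dims=dims, num_heads=num_heads)

    src = mx.random.normal((2, 16, dims))  # source sequence embeddings
    tgt = mx.random.normal((2, 10, dims))  # target sequence embeddings

    src_mask = None                                                   # e.g. an additive padding mask
    tgt_mask = nn.MultiHeadAttention.create_additive_causal_mask(10)  # causal self-attention mask

    memory = encoder(src, src_mask)
    out = decoder(tgt, memory, tgt_mask, None)  # (2, 10, dims)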