Feature Addition: Encoder-Decoder Transformer Architecture (#50)

* Implemented TransformerDecoderLayer and TransformerDecoder, and introduced an encoder-decoder Transformer

* Added ReLU layer

* Added src, tgt, and memory masks

---------

Co-authored-by: rushyam <rushyam@rushyams-MacBook-Air.local>
Authored by rushyam on 2023-12-07 21:07:36 +05:30, committed by GitHub
parent dfbc52ce56
commit 2e126aeb7e


@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
import math
from typing import Optional, Any
import mlx.core as mx
from mlx.nn.layers.base import Module
@@ -136,3 +136,85 @@ class TransformerEncoder(Module):
        x = self.ln(x)
        return x
class TransformerDecoderLayer(Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.self_attention = MultiHeadAttention(dims, num_heads)
        self.cross_attention = MultiHeadAttention(dims, num_heads)
        self.ln1 = LayerNorm(dims)
        self.ln2 = LayerNorm(dims)
        self.ln3 = LayerNorm(dims)
        self.linear1 = Linear(dims, mlp_dims)
        self.linear2 = Linear(mlp_dims, dims)

    def __call__(self, x, memory, x_mask, memory_mask):
        # Pre-norm self-attention over the target sequence, with a residual connection
        y = self.ln1(x)
        y = self.self_attention(y, y, y, x_mask)
        x = x + y

        # Pre-norm cross-attention over the encoder output (memory), with a residual connection
        y = self.ln2(x)
        y = self.cross_attention(y, memory, memory, memory_mask)
        x = x + y

        # Pre-norm feed-forward block with a ReLU activation, with a residual connection
        y = self.ln3(x)
        y = self.linear1(y)
        y = mx.maximum(y, 0)
        y = self.linear2(y)
        x = x + y

        return x
class TransformerDecoder(Module):
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
    ):
        super().__init__()
        self.layers = [
            TransformerDecoderLayer(dims, num_heads, mlp_dims)
            for _ in range(num_layers)
        ]
        self.ln = LayerNorm(dims)

    def __call__(self, x, memory, x_mask, memory_mask):
        # Run the target sequence through each decoder layer, then apply a final layer norm
        for layer in self.layers:
            x = layer(x, memory, x_mask, memory_mask)
        x = self.ln(x)
        return x
class Transformer(Module):
    def __init__(
        self,
        dims: int = 512,
        num_heads: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        mlp_dims: Optional[int] = None,
        custom_encoder: Optional[Any] = None,
        custom_decoder: Optional[Any] = None,
    ):
        super().__init__()

        # Use a caller-supplied encoder/decoder if given, otherwise build the defaults
        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            self.encoder = TransformerEncoder(
                num_encoder_layers, dims, num_heads, mlp_dims
            )

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            self.decoder = TransformerDecoder(
                num_decoder_layers, dims, num_heads, mlp_dims
            )

    def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
        # Encode the source sequence, then decode the target against the encoder output
        memory = self.encoder(src, src_mask)
        output = self.decoder(tgt, memory, tgt_mask, memory_mask)
        return output
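
For reference, a minimal usage sketch of the new encoder-decoder Transformer (not part of the commit) is shown below. It assumes the new classes are exported from mlx.nn and that nn.MultiHeadAttention.create_additive_causal_mask is available to build the causal target mask; the shapes and layer counts are illustrative only.

# Minimal usage sketch; assumes nn.Transformer and the causal-mask helper
# are exported from mlx.nn.
import mlx.core as mx
import mlx.nn as nn

dims, num_heads = 512, 8
model = nn.Transformer(
    dims=dims,
    num_heads=num_heads,
    num_encoder_layers=2,
    num_decoder_layers=2,
)

# A batch of 4 source sequences (length 10) and target sequences (length 7),
# already embedded to `dims` features.
src = mx.random.normal((4, 10, dims))
tgt = mx.random.normal((4, 7, dims))

# Additive causal mask so each target position only attends to earlier positions.
tgt_mask = nn.MultiHeadAttention.create_additive_causal_mask(tgt.shape[1])

# No padding in this toy example, so the source and memory masks are None.
out = model(src, tgt, None, tgt_mask, None)
print(out.shape)  # (4, 7, 512)

Note that __call__ takes src_mask, tgt_mask, and memory_mask positionally, so None is passed explicitly wherever no masking is needed.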