Mirror of https://github.com/ml-explore/mlx.git
Feature Addition: Encoder-Decoder Transformer Architecture (#50)
* Implemented the decoder transformer layer and decoder transformer, and introduced the encoder-decoder transformer
* Added a ReLU layer
* Added src, tgt, and memory masks

Co-authored-by: rushyam <rushyam@rushyams-MacBook-Air.local>
Parent: dfbc52ce56
Commit: 2e126aeb7e
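The src, tgt, and memory masks mentioned in the commit message are additive attention masks: they are added to the raw attention scores, so blocked positions carry large negative values rather than booleans. A minimal sketch of building a causal target mask, assuming the create_additive_causal_mask helper that MultiHeadAttention provides in mlx.nn (the target length here is hypothetical):

import mlx.nn as nn

# Hypothetical target length, for illustration only.
tgt_len = 7

# Square (tgt_len, tgt_len) mask whose upper triangle is a large negative
# number, so each position attends only to itself and earlier positions.
tgt_mask = nn.MultiHeadAttention.create_additive_causal_mask(tgt_len)
print(tgt_mask.shape)  # (7, 7)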
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import Optional
+from typing import Optional, Any
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -136,3 +136,85 @@ class TransformerEncoder(Module):
         x = self.ln(x)
 
         return x
+
+
+class TransformerDecoderLayer(Module):
+    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
+        super().__init__()
+        mlp_dims = mlp_dims or dims * 4
+        self.self_attention = MultiHeadAttention(dims, num_heads)
+        self.cross_attention = MultiHeadAttention(dims, num_heads)
+        self.ln1 = LayerNorm(dims)
+        self.ln2 = LayerNorm(dims)
+        self.ln3 = LayerNorm(dims)
+        self.linear1 = Linear(dims, mlp_dims)
+        self.linear2 = Linear(mlp_dims, dims)
+
+    def __call__(self, x, memory, x_mask, memory_mask):
+        y = self.ln1(x)
+        y = self.self_attention(y, y, y, x_mask)
+        x = x + y
+
+        y = self.ln2(x)
+        y = self.cross_attention(x, memory, memory, memory_mask)
+        x = x + y
+
+        y = self.ln3(x)
+        y = self.linear1(y)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+        x = x + y
+
+        return x
+
+
+class TransformerDecoder(Module):
+    def __init__(
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+    ):
+        super().__init__()
+        self.layers = [
+            TransformerDecoderLayer(dims, num_heads, mlp_dims)
+            for i in range(num_layers)
+        ]
+        self.ln = LayerNorm(dims)
+
+    def __call__(self, x, memory, x_mask, memory_mask):
+        for l in self.layers:
+            x = l(x, memory, x_mask, memory_mask)
+        x = self.ln(x)
+
+        return x
+
+
+class Transformer(Module):
+    def __init__(
+        self,
+        dims: int = 512,
+        num_heads: int = 8,
+        num_encoder_layers: int = 6,
+        num_decoder_layers: int = 6,
+        mlp_dims: Optional[int] = None,
+        custom_encoder: Optional[Any] = None,
+        custom_decoder: Optional[Any] = None,
+    ):
+        super().__init__()
+        if custom_encoder is not None:
+            self.encoder = custom_encoder
+        else:
+            self.encoder = TransformerEncoder(
+                num_encoder_layers, dims, num_heads, mlp_dims
+            )
+
+        if custom_decoder is not None:
+            self.decoder = custom_decoder
+        else:
+            self.decoder = TransformerDecoder(
+                num_decoder_layers, dims, num_heads, mlp_dims
+            )
+
+    def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
+        memory = self.encoder(src, src_mask)
+        output = self.decoder(tgt, memory, tgt_mask, memory_mask)
+
+        return output
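For context, a minimal usage sketch of the new Transformer class, assuming the classes above are exported through mlx.nn (otherwise import them from the transformer layers module directly). The batch size, sequence lengths, and layer counts below are hypothetical, and the inputs are assumed to be already-embedded sequences:

import mlx.core as mx
import mlx.nn as nn

dims, src_len, tgt_len = 512, 10, 7

# Small encoder-decoder stack for illustration.
model = nn.Transformer(
    dims=dims, num_heads=8, num_encoder_layers=2, num_decoder_layers=2
)

src = mx.random.normal((2, src_len, dims))  # embedded source batch
tgt = mx.random.normal((2, tgt_len, dims))  # embedded target batch

# Additive causal mask for the decoder's self-attention; the source and
# memory masks are left as None since this toy batch has no padding.
tgt_mask = nn.MultiHeadAttention.create_additive_causal_mask(tgt_len)

out = model(src, tgt, None, tgt_mask, None)
print(out.shape)  # (2, 7, 512)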