mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Pass config to all modules, fix ln
This commit is contained in:
parent
c0001a94f2
commit
62924d8135
31
t5/t5.py
31
t5/t5.py
@ -46,14 +46,14 @@ class LayerNorm(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
class TransformerEncoderLayer(nn.Module):
|
||||||
def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
mlp_dims = mlp_dims or dims * 4
|
mlp_dims = config.d_ff or config.d_model * 4
|
||||||
self.attention = nn.MultiHeadAttention(dims, num_heads)
|
self.attention = nn.MultiHeadAttention(config.d_model, config.num_heads)
|
||||||
self.ln1 = LayerNorm(dims)
|
self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.ln2 = LayerNorm(dims)
|
self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
|
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||||
self.linear2 = nn.Linear(mlp_dims, dims, bias=False)
|
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
|
||||||
|
|
||||||
def __call__(self, x, mask):
|
def __call__(self, x, mask):
|
||||||
y = self.ln1(x)
|
y = self.ln1(x)
|
||||||
@ -70,15 +70,13 @@ class TransformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TransformerEncoder(nn.Module):
|
class TransformerEncoder(nn.Module):
|
||||||
def __init__(
|
def __init__(self, config: ModelArgs):
|
||||||
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = [
|
self.layers = [
|
||||||
TransformerEncoderLayer(dims, num_heads, mlp_dims)
|
TransformerEncoderLayer(config)
|
||||||
for _ in range(num_layers)
|
for _ in range(config.num_layers)
|
||||||
]
|
]
|
||||||
self.ln = LayerNorm(dims)
|
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
def __call__(self, x, mask):
|
def __call__(self, x, mask):
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
@ -91,12 +89,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
class T5(nn.Module):
|
class T5(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||||
self.encoder = TransformerEncoder(
|
self.encoder = TransformerEncoder(config)
|
||||||
num_layers=config.num_layers,
|
|
||||||
dims=config.d_model,
|
|
||||||
num_heads=config.num_heads,
|
|
||||||
mlp_dims=config.d_ff,
|
|
||||||
)
|
|
||||||
# self.decoder = TransformerDecoder(config)
|
# self.decoder = TransformerDecoder(config)
|
||||||
# self.lm_head = OutputHead(config)
|
# self.lm_head = OutputHead(config)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user