Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-30 10:56:38 +08:00)
Use position bias in decoder
This commit is contained in:
Parent: 7dcf2b688d
Commit: daea1dcddf
15
t5/t5.py
15
t5/t5.py
@ -232,13 +232,13 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||||
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
|
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
|
||||||
|
|
||||||
def __call__(self, x, memory, x_mask, memory_mask):
|
def __call__(self, x, memory, x_mask, memory_mask, position_bias=None):
|
||||||
y = self.ln1(x)
|
y = self.ln1(x)
|
||||||
y = self.self_attention(y, y, y, x_mask)
|
y, position_bias = self.self_attention(y, y, y, x_mask, position_bias=position_bias)
|
||||||
x = x + y
|
x = x + y
|
||||||
|
|
||||||
y = self.ln2(x)
|
y = self.ln2(x)
|
||||||
y = self.cross_attention(x, memory, memory, memory_mask)
|
y, _ = self.cross_attention(x, memory, memory, memory_mask)
|
||||||
x = x + y
|
x = x + y
|
||||||
|
|
||||||
y = self.ln3(x)
|
y = self.ln3(x)
|
||||||
@ -247,7 +247,7 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
y = self.linear2(y)
|
y = self.linear2(y)
|
||||||
x = x + y
|
x = x + y
|
||||||
|
|
||||||
return x
|
return x, position_bias
|
||||||
|
|
||||||
|
|
||||||
class TransformerDecoder(nn.Module):
|
class TransformerDecoder(nn.Module):
|
||||||
@ -260,8 +260,11 @@ class TransformerDecoder(nn.Module):
|
|||||||
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
def __call__(self, x, memory, x_mask, memory_mask):
|
def __call__(self, x, memory, x_mask, memory_mask):
|
||||||
|
position_bias = None
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x, memory, x_mask, memory_mask)
|
x, position_bias = layer(
|
||||||
|
x, memory, x_mask, memory_mask, position_bias=position_bias
|
||||||
|
)
|
||||||
x = self.ln(x)
|
x = self.ln(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
@ -298,7 +301,7 @@ class T5(nn.Module):
|
|||||||
mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
|
mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
|
||||||
mask = mask.astype(x.dtype)
|
mask = mask.astype(x.dtype)
|
||||||
|
|
||||||
y, cache = self.decoder(
|
y = self.decoder(
|
||||||
x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
|
x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
|
||||||
) # , cache)
|
) # , cache)
|
||||||
return self.lm_head(y), cache
|
return self.lm_head(y), cache
|
||||||
|
Loading…
Reference in New Issue
Block a user