Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-30 02:53:41 +08:00
Use position bias in all layers
parent 203f550ef9
commit 7e42349f4c
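In T5 only the first attention layer constructs a RelativePositionBias module, but the bias it computes is meant to be shared by every layer. Previously the bias was applied only inside that first layer; with this change the attention returns the bias to its caller and each encoder layer threads it into the next, so all layers add the same position bias to their attention scores.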
t5/t5.py: 16 changed lines (10 additions, 6 deletions)
@@ -120,7 +120,7 @@ class MultiHeadAttention(nn.Module):
         if has_relative_attention_bias:
             self.relative_attention_bias = RelativePositionBias(config)
 
-    def __call__(self, queries, keys, values, mask=None):
+    def __call__(self, queries, keys, values, mask=None, position_bias=None):
         queries = self.query_proj(queries)
         keys = self.key_proj(keys)
         values = self.value_proj(values)
@@ -140,10 +140,11 @@ class MultiHeadAttention(nn.Module):
 
         if self.has_relative_attention_bias:
             position_bias = self.relative_attention_bias(L, S)
+        if position_bias is not None:
             scores += position_bias
         scores = mx.softmax(scores, axis=-1)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.out_proj(values_hat)
+        return self.out_proj(values_hat), position_bias
 
     @staticmethod
     def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
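The core of the fix is in MultiHeadAttention: computing the bias (guarded by has_relative_attention_bias) is now decoupled from applying it (guarded by position_bias is not None), and the bias is returned alongside the output so the encoder can hand it to the next layer.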
@@ -185,9 +186,11 @@ class TransformerEncoderLayer(nn.Module):
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
         self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
 
-    def __call__(self, x, mask):
+    def __call__(self, x, mask, position_bias=None):
         y = self.ln1(x)
-        y = self.attention(y, y, y, mask)
+        y, position_bias = self.attention(
+            queries=y, keys=y, values=y, mask=mask, position_bias=position_bias
+        )
         x = x + y
 
         y = self.ln2(x)
@@ -196,7 +199,7 @@ class TransformerEncoderLayer(nn.Module):
         y = self.linear2(y)
         x = x + y
 
-        return x
+        return x, position_bias
 
 
 class TransformerEncoder(nn.Module):
@@ -209,8 +212,9 @@ class TransformerEncoder(nn.Module):
         self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
 
     def __call__(self, x, mask):
+        position_bias = None
         for layer in self.layers:
-            x = layer(x, mask)
+            x, position_bias = layer(x, mask, position_bias=position_bias)
         x = self.ln(x)
 
         return x
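For reference, a minimal self-contained sketch of the resulting pattern. ToyLayer and ToyEncoder below are illustrative stand-ins, not the classes from t5/t5.py, and the flat offset clipping is a placeholder for T5's log-spaced relative-position buckets; the point is the threading: layer 0 computes the bias, returns it, and every layer applies whatever bias it is handed.

import mlx.core as mx
import mlx.nn as nn

class ToyLayer(nn.Module):
    """Illustrative stand-in for TransformerEncoderLayer (not the real class)."""

    def __init__(self, dims, has_relative_attention_bias=False):
        super().__init__()
        self.has_relative_attention_bias = has_relative_attention_bias
        if has_relative_attention_bias:
            # Placeholder for RelativePositionBias: one learned scalar per
            # clipped relative offset (real T5 uses log-spaced buckets per head).
            self.bias_table = nn.Embedding(32, 1)

    def __call__(self, x, position_bias=None):
        _, L, D = x.shape
        if self.has_relative_attention_bias:
            # Only the bias-owning layer (layer 0) computes the (L, L) bias.
            rel = mx.arange(L)[None, :] - mx.arange(L)[:, None]
            buckets = mx.clip(rel + 16, 0, 31)
            position_bias = self.bias_table(buckets).squeeze(-1)
        scores = (x @ x.transpose(0, 2, 1)) / D**0.5
        if position_bias is not None:
            # Every layer applies the shared bias, whether it computed the
            # bias itself or received it from the previous layer.
            scores = scores + position_bias
        return mx.softmax(scores, axis=-1) @ x, position_bias

class ToyEncoder(nn.Module):
    def __init__(self, num_layers, dims):
        super().__init__()
        # As in t5.py, only the first layer owns the bias table.
        self.layers = [
            ToyLayer(dims, has_relative_attention_bias=(i == 0))
            for i in range(num_layers)
        ]

    def __call__(self, x):
        position_bias = None
        for layer in self.layers:
            # Thread the bias computed by layer 0 through every layer.
            x, position_bias = layer(x, position_bias=position_bias)
        return x

Running enc = ToyEncoder(num_layers=3, dims=8) on an mx.random.normal((2, 5, 8)) batch exercises the same flow the diff introduces: the bias is built once and reused by all three layers.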