Use position bias in all layers

Juarez Bochi 2023-12-17 07:19:32 -05:00
parent 203f550ef9
commit 7e42349f4c


@@ -120,7 +120,7 @@ class MultiHeadAttention(nn.Module):
         if has_relative_attention_bias:
             self.relative_attention_bias = RelativePositionBias(config)
 
-    def __call__(self, queries, keys, values, mask=None):
+    def __call__(self, queries, keys, values, mask=None, position_bias=None):
         queries = self.query_proj(queries)
         keys = self.key_proj(keys)
         values = self.value_proj(values)
@@ -140,10 +140,11 @@ class MultiHeadAttention(nn.Module):
 
         if self.has_relative_attention_bias:
             position_bias = self.relative_attention_bias(L, S)
+        if position_bias is not None:
             scores += position_bias
         scores = mx.softmax(scores, axis=-1)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.out_proj(values_hat)
+        return self.out_proj(values_hat), position_bias
 
     @staticmethod
     def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
@@ -185,9 +186,11 @@ class TransformerEncoderLayer(nn.Module):
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
         self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
 
-    def __call__(self, x, mask):
+    def __call__(self, x, mask, position_bias=None):
         y = self.ln1(x)
-        y = self.attention(y, y, y, mask)
+        y, position_bias = self.attention(
+            queries=y, keys=y, values=y, mask=mask, position_bias=position_bias
+        )
         x = x + y
 
         y = self.ln2(x)
@@ -196,7 +199,7 @@ class TransformerEncoderLayer(nn.Module):
         y = self.linear2(y)
         x = x + y
 
-        return x
+        return x, position_bias
 
 
 class TransformerEncoder(nn.Module):
@@ -209,8 +212,9 @@ class TransformerEncoder(nn.Module):
         self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
 
     def __call__(self, x, mask):
+        position_bias = None
         for layer in self.layers:
-            x = layer(x, mask)
+            x, position_bias = layer(x, mask, position_bias=position_bias)
         x = self.ln(x)
         return x
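
In effect, only the first encoder layer (the one built with has_relative_attention_bias=True) computes the T5 relative position bias; the encoder then threads that tensor through every subsequent layer so each one adds it to its attention scores instead of recomputing it. A minimal, framework-free sketch of that sharing pattern is below; the ToyLayer class, the numpy stand-ins, and the shapes are illustrative assumptions, not part of the actual model code.

import numpy as np

class ToyLayer:
    """Hypothetical stand-in for TransformerEncoderLayer: shows only how a
    shared position_bias flows through the stack of layers."""

    def __init__(self, has_relative_attention_bias: bool, num_heads: int = 2):
        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_heads = num_heads

    def __call__(self, x, position_bias=None):
        B, L, _ = x.shape
        if self.has_relative_attention_bias:
            # Stand-in for RelativePositionBias(config)(L, S):
            # one (1, num_heads, L, L) bias table, computed once in layer 0.
            position_bias = np.random.randn(1, self.num_heads, L, L)
        # Real attention scores would be queries @ keys; only the bias add is shown.
        scores = np.zeros((B, self.num_heads, L, L))
        if position_bias is not None:
            scores = scores + position_bias
        # The bias is returned so the next layer can reuse the same tensor.
        return x, position_bias

# Only layer 0 owns the bias table, mirroring the diff above.
layers = [ToyLayer(has_relative_attention_bias=(i == 0)) for i in range(3)]

x = np.zeros((1, 4, 8))  # (batch, length, d_model)
position_bias = None
for layer in layers:
    x, position_bias = layer(x, position_bias=position_bias)

print(position_bias.shape)  # (1, 2, 4, 4): the same bias reused by every layer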