diff --git a/t5/t5.py b/t5/t5.py
index 9670df16..a59e5ab4 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -120,7 +120,7 @@ class MultiHeadAttention(nn.Module):
         if has_relative_attention_bias:
             self.relative_attention_bias = RelativePositionBias(config)
 
-    def __call__(self, queries, keys, values, mask=None):
+    def __call__(self, queries, keys, values, mask=None, position_bias=None):
         queries = self.query_proj(queries)
         keys = self.key_proj(keys)
         values = self.value_proj(values)
@@ -140,10 +140,11 @@ class MultiHeadAttention(nn.Module):
 
         if self.has_relative_attention_bias:
             position_bias = self.relative_attention_bias(L, S)
+        if position_bias is not None:
             scores += position_bias
         scores = mx.softmax(scores, axis=-1)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.out_proj(values_hat)
+        return self.out_proj(values_hat), position_bias
 
     @staticmethod
     def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
@@ -185,9 +186,11 @@ class TransformerEncoderLayer(nn.Module):
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
         self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
 
-    def __call__(self, x, mask):
+    def __call__(self, x, mask, position_bias=None):
         y = self.ln1(x)
-        y = self.attention(y, y, y, mask)
+        y, position_bias = self.attention(
+            queries=y, keys=y, values=y, mask=mask, position_bias=position_bias
+        )
         x = x + y
 
         y = self.ln2(x)
@@ -196,7 +199,7 @@ class TransformerEncoderLayer(nn.Module):
         y = self.linear2(y)
         x = x + y
 
-        return x
+        return x, position_bias
 
 
 class TransformerEncoder(nn.Module):
@@ -209,8 +212,9 @@ class TransformerEncoder(nn.Module):
         self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
 
     def __call__(self, x, mask):
+        position_bias = None
         for layer in self.layers:
-            x = layer(x, mask)
+            x, position_bias = layer(x, mask, position_bias=position_bias)
         x = self.ln(x)
         return x
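
Context for the patch above: in T5 only the first block computes the relative position bias, and the remaining layers are meant to reuse it. The change therefore makes `MultiHeadAttention.__call__` return the bias it computes (or was handed), and the encoder threads that value through every layer instead of leaving later layers without it. Below is a minimal, self-contained sketch of that threading pattern; `ToyLayer` and the NumPy arrays are hypothetical stand-ins for illustration, not the MLX code in this diff.

```python
# Hypothetical sketch of the bias-threading pattern (plain NumPy, not mlx).
import numpy as np


class ToyLayer:
    """Stand-in for an encoder layer: only the first layer builds the bias."""

    def __init__(self, has_relative_attention_bias: bool):
        self.has_relative_attention_bias = has_relative_attention_bias

    def __call__(self, x, position_bias=None):
        L = x.shape[0]
        if self.has_relative_attention_bias:
            # First layer computes the (L, L) bias table once.
            position_bias = np.zeros((L, L))
        # Every layer adds the (possibly cached) bias to its attention scores.
        scores = np.zeros((L, L))
        if position_bias is not None:
            scores = scores + position_bias
        return x, position_bias


layers = [ToyLayer(has_relative_attention_bias=(i == 0)) for i in range(3)]
x, position_bias = np.ones(4), None
for layer in layers:
    x, position_bias = layer(x, position_bias=position_bias)
assert position_bias is not None  # computed once, reused by the later layers
```

Returning the bias from `__call__` rather than stashing it on the module keeps the layers stateless, which is why the encoder loop in the diff threads `position_bias` explicitly.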