diff --git a/t5/t5.py b/t5/t5.py index d061e225..a60b6246 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -124,9 +124,6 @@ class MultiHeadAttention(nn.Module): queries = self.query_proj(queries) keys = self.key_proj(keys) values = self.value_proj(values) - # print(f"queries: {queries}, {queries.abs().sum()}") - # print(f"keys: {keys}, {keys.abs().sum()}") - # print(f"values: {values}, {values.abs().sum()}") num_heads = self.num_heads B, L, _ = queries.shape @@ -137,7 +134,6 @@ class MultiHeadAttention(nn.Module): # Dimensions are [batch x num heads x sequence x hidden dim] scores = queries @ keys - if mask is not None: scores = scores + mask.astype(scores.dtype) @@ -147,7 +143,8 @@ class MultiHeadAttention(nn.Module): scores += position_bias scores = mx.softmax(scores, axis=-1) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.out_proj(values_hat), position_bias + out = self.out_proj(values_hat) + return out, position_bias @staticmethod def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):