From 61fda57eba57b7d23da4f2a21edfba5a9d5f5acb Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Sun, 17 Dec 2023 08:52:54 -0500 Subject: [PATCH] Remove prints --- t5/t5.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/t5/t5.py b/t5/t5.py index d061e225..a60b6246 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -124,9 +124,6 @@ class MultiHeadAttention(nn.Module): queries = self.query_proj(queries) keys = self.key_proj(keys) values = self.value_proj(values) - # print(f"queries: {queries}, {queries.abs().sum()}") - # print(f"keys: {keys}, {keys.abs().sum()}") - # print(f"values: {values}, {values.abs().sum()}") num_heads = self.num_heads B, L, _ = queries.shape @@ -137,7 +134,6 @@ class MultiHeadAttention(nn.Module): # Dimensions are [batch x num heads x sequence x hidden dim] scores = queries @ keys - if mask is not None: scores = scores + mask.astype(scores.dtype) @@ -147,7 +143,8 @@ class MultiHeadAttention(nn.Module): scores += position_bias scores = mx.softmax(scores, axis=-1) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.out_proj(values_hat), position_bias + out = self.out_proj(values_hat) + return out, position_bias @staticmethod def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):