Remove prints

This commit is contained in:
Juarez Bochi 2023-12-17 08:52:54 -05:00
parent 152e85fade
commit 61fda57eba
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@@ -124,9 +124,6 @@ class MultiHeadAttention(nn.Module):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
-# print(f"queries: {queries}, {queries.abs().sum()}")
-# print(f"keys: {keys}, {keys.abs().sum()}")
-# print(f"values: {values}, {values.abs().sum()}")
num_heads = self.num_heads
B, L, _ = queries.shape
@@ -137,7 +134,6 @@ class MultiHeadAttention(nn.Module):
# Dimensions are [batch x num heads x sequence x hidden dim]
scores = queries @ keys
if mask is not None:
scores = scores + mask.astype(scores.dtype)
@@ -147,7 +143,8 @@ class MultiHeadAttention(nn.Module):
scores += position_bias
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-return self.out_proj(values_hat), position_bias
+out = self.out_proj(values_hat)
+return out, position_bias
@staticmethod
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):