mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Remove prints
This commit is contained in:
parent
152e85fade
commit
61fda57eba
7
t5/t5.py
7
t5/t5.py
@@ -124,9 +124,6 @@ class MultiHeadAttention(nn.Module):
|
||||
queries = self.query_proj(queries)
|
||||
keys = self.key_proj(keys)
|
||||
values = self.value_proj(values)
|
||||
- # print(f"queries: {queries}, {queries.abs().sum()}")
|
||||
- # print(f"keys: {keys}, {keys.abs().sum()}")
|
||||
- # print(f"values: {values}, {values.abs().sum()}")
|
||||
|
||||
num_heads = self.num_heads
|
||||
B, L, _ = queries.shape
|
||||
@@ -137,7 +134,6 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||
scores = queries @ keys
|
||||
|
||||
if mask is not None:
|
||||
scores = scores + mask.astype(scores.dtype)
|
||||
|
||||
@@ -147,7 +143,8 @@ class MultiHeadAttention(nn.Module):
|
||||
scores += position_bias
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
- return self.out_proj(values_hat), position_bias
|
||||
+ out = self.out_proj(values_hat)
|
||||
+ return out, position_bias
|
||||
|
||||
@staticmethod
|
||||
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
|
||||
|
Loading…
Reference in New Issue
Block a user