mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Remove prints
This commit is contained in:
parent
152e85fade
commit
61fda57eba
7
t5/t5.py
7
t5/t5.py
@@ -124,9 +124,6 @@ class MultiHeadAttention(nn.Module):
|
|||||||
queries = self.query_proj(queries)
|
queries = self.query_proj(queries)
|
||||||
keys = self.key_proj(keys)
|
keys = self.key_proj(keys)
|
||||||
values = self.value_proj(values)
|
values = self.value_proj(values)
|
||||||
# print(f"queries: {queries}, {queries.abs().sum()}")
|
|
||||||
# print(f"keys: {keys}, {keys.abs().sum()}")
|
|
||||||
# print(f"values: {values}, {values.abs().sum()}")
|
|
||||||
|
|
||||||
num_heads = self.num_heads
|
num_heads = self.num_heads
|
||||||
B, L, _ = queries.shape
|
B, L, _ = queries.shape
|
||||||
@@ -137,7 +134,6 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||||
scores = queries @ keys
|
scores = queries @ keys
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
scores = scores + mask.astype(scores.dtype)
|
scores = scores + mask.astype(scores.dtype)
|
||||||
|
|
||||||
@@ -147,7 +143,8 @@ class MultiHeadAttention(nn.Module):
|
|||||||
scores += position_bias
|
scores += position_bias
|
||||||
scores = mx.softmax(scores, axis=-1)
|
scores = mx.softmax(scores, axis=-1)
|
||||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.out_proj(values_hat), position_bias
|
out = self.out_proj(values_hat)
|
||||||
|
return out, position_bias
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
|
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
|
||||||
|
Loading…
Reference in New Issue
Block a user