No scaling, no encoder mask

This commit is contained in:
Juarez Bochi 2023-12-16 14:24:13 -05:00
parent 64e7eaccb8
commit d12db65eeb
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@@ -125,8 +125,8 @@ class MultiHeadAttention(nn.Module):
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys
scores = queries @ keys
if mask is not None:
scores = scores + mask.astype(scores.dtype)
@@ -274,12 +274,14 @@ class T5(nn.Module):
) -> tuple[mx.array, mx.array]:
x = self.wte(inputs)
y = self.encoder(x, mask=None) #, cache)
mask = None
if x.shape[1] > 1:
mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
y = self.encoder(x, mask) #, cache)
# y, cache = self.decoder(x, mask, cache)
# return self.lm_head(y), cache
return y #, cache