diff --git a/t5/t5.py b/t5/t5.py
index 155393bb..6e3374a6 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -125,8 +125,8 @@ class MultiHeadAttention(nn.Module):
         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
 
         # Dimensions are [batch x num heads x sequence x hidden dim]
-        scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys
+        scores = queries @ keys
+
         if mask is not None:
             scores = scores + mask.astype(scores.dtype)
 
@@ -274,12 +274,14 @@ class T5(nn.Module):
     ) -> tuple[mx.array, mx.array]:
         x = self.wte(inputs)
+
+        y = self.encoder(x, mask=None) #, cache)
+
         mask = None
         if x.shape[1] > 1:
             mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
             mask = mask.astype(x.dtype)
 
-        y = self.encoder(x, mask) #, cache)
         # y, cache = self.decoder(x, mask, cache)
         # return self.lm_head(y), cache
         return y #, cache
 