diff --git a/t5/t5.py b/t5/t5.py index bca34645..b33258d5 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -125,12 +125,12 @@ class MultiHeadAttention(nn.Module): values = mx.concatenate([value_cache, values], axis=2) # Dimensions are [batch x num heads x sequence x hidden dim] - queries = queries.astype(mx.float32) - scores = queries @ keys.astype(mx.float32) + queries = queries + scores = queries @ keys if mask is not None: scores = scores + mask.astype(scores.dtype) - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(values.dtype) + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) return self.out_proj(values_hat), (keys, values) @@ -148,7 +148,7 @@ class RMSNorm(nn.Module): t = x.dtype if t == mx.float16: x = x.astype(mx.float32) - x = mx.clip(x, a_min=-1e6, a_max=1e6) + x = mx.clip(x, a_min=-1e9, a_max=1e9) output = self._norm(x).astype(t) return self.weight * output @@ -295,7 +295,6 @@ class T5(nn.Module): y, cache = self.decoder( inputs, memory=memory, mask=mask, memory_mask=None, cache=cache ) - y = y.astype(mx.float32) if self.tie_word_embeddings: y *= self.model_dim**-0.5 return self.lm_head(y), cache