diff --git a/t5/t5.py b/t5/t5.py
index 2d736fbe..bca34645 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -125,11 +125,12 @@ class MultiHeadAttention(nn.Module):
             values = mx.concatenate([value_cache, values], axis=2)
 
         # Dimensions are [batch x num heads x sequence x hidden dim]
-        scores = queries @ keys
+        queries = queries.astype(mx.float32)
+        scores = queries @ keys.astype(mx.float32)
         if mask is not None:
             scores = scores + mask.astype(scores.dtype)
 
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
+        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(values.dtype)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
 
         return self.out_proj(values_hat), (keys, values)
@@ -144,7 +145,11 @@ class RMSNorm(nn.Module):
         return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
 
     def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
+        t = x.dtype
+        if t == mx.float16:
+            x = x.astype(mx.float32)
+            x = mx.clip(x, a_min=-1e6, a_max=1e6)
+        output = self._norm(x).astype(t)
         return self.weight * output
 
 
@@ -290,6 +295,7 @@ class T5(nn.Module):
         y, cache = self.decoder(
             inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
         )
+        y = y.astype(mx.float32)
         if self.tie_word_embeddings:
             y *= self.model_dim**-0.5
         return self.lm_head(y), cache
@@ -409,7 +415,7 @@ if __name__ == "__main__":
 
     mx.random.seed(args.seed)
 
-    model, tokenizer = load_model(args.model)
+    model, tokenizer = load_model(args.model, args.dtype)
 
     if args.encode_only:
         print("[INFO] Encoding with T5...", flush=True)
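
The common thread in these hunks is selective float32 upcasting so the model stays numerically stable when its weights are float16: attention logits and the softmax run in float32, RMSNorm upcasts and clips float16 inputs to suppress overflow-induced infs, and the decoder output is upcast before the tied-embedding scaling and the lm_head projection. Below is a minimal standalone sketch of the attention-score pattern (not the PR's code verbatim); it assumes mlx is installed, that queries/values use the [batch, heads, seq, dim] layout noted in the first hunk, and that keys has already been transposed to [batch, heads, dim, seq] as in the surrounding code:

    import mlx.core as mx

    def stable_attention(queries, keys, values, mask=None):
        # Accumulate query-key logits in float32 so large float16
        # activations do not overflow to inf before the softmax.
        scores = queries.astype(mx.float32) @ keys.astype(mx.float32)
        if mask is not None:
            scores = scores + mask.astype(scores.dtype)
        # Softmax in float32 for stability, then cast the attention
        # weights back to the values' dtype so the second matmul runs
        # in the model's compute precision.
        scores = mx.softmax(scores, axis=-1).astype(values.dtype)
        return scores @ values

The last hunk threads a dtype through load_model; assuming a matching argparse flag is registered elsewhere in the PR (implied by args.dtype), the half-precision path would be exercised with something like `python t5.py --model t5-small --dtype float16`.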