diff --git a/t5/t5.py b/t5/t5.py index b33258d5..11476da4 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -146,9 +146,6 @@ class RMSNorm(nn.Module): def __call__(self, x): t = x.dtype - if t == mx.float16: - x = x.astype(mx.float32) - x = mx.clip(x, a_min=-1e9, a_max=1e9) output = self._norm(x).astype(t) return self.weight * output @@ -406,7 +403,7 @@ if __name__ == "__main__": help="The model data type.", type=str, choices=["float16", "bfloat16", "float32"], - default="float16", + default="float32", ) parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")