From 688795c665c249ca1e8882fb89a59c4878a2484a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 18 Dec 2023 17:15:49 -0800 Subject: [PATCH] default to fp32 for now --- t5/t5.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/t5/t5.py b/t5/t5.py index b33258d5..11476da4 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -146,9 +146,6 @@ class RMSNorm(nn.Module): def __call__(self, x): t = x.dtype - if t == mx.float16: - x = x.astype(mx.float32) - x = mx.clip(x, a_min=-1e9, a_max=1e9) output = self._norm(x).astype(t) return self.weight * output @@ -406,7 +403,7 @@ if __name__ == "__main__": help="The model data type.", type=str, choices=["float16", "bfloat16", "float32"], - default="float16", + default="float32", ) parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")