default to fp32 for now

This commit is contained in:
Awni Hannun 2023-12-18 17:15:49 -08:00
parent 05a8464d78
commit 688795c665

View File

@ -146,9 +146,6 @@ class RMSNorm(nn.Module):
def __call__(self, x):
t = x.dtype
if t == mx.float16:
x = x.astype(mx.float32)
x = mx.clip(x, a_min=-1e9, a_max=1e9)
output = self._norm(x).astype(t)
return self.weight * output
@ -406,7 +403,7 @@ if __name__ == "__main__":
help="The model data type.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="float16",
default="float32",
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")