mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
default to fp32 for now
This commit is contained in:
parent
05a8464d78
commit
688795c665
5
t5/t5.py
5
t5/t5.py
@ -146,9 +146,6 @@ class RMSNorm(nn.Module):
|
||||
|
||||
def __call__(self, x):
|
||||
t = x.dtype
|
||||
if t == mx.float16:
|
||||
x = x.astype(mx.float32)
|
||||
x = mx.clip(x, a_min=-1e9, a_max=1e9)
|
||||
output = self._norm(x).astype(t)
|
||||
return self.weight * output
|
||||
|
||||
@ -406,7 +403,7 @@ if __name__ == "__main__":
|
||||
help="The model data type.",
|
||||
type=str,
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
default="float16",
|
||||
default="float32",
|
||||
)
|
||||
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
|
Loading…
Reference in New Issue
Block a user