mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
default to fp32 for now
This commit is contained in:
parent
05a8464d78
commit
688795c665
5
t5/t5.py
5
t5/t5.py
@ -146,9 +146,6 @@ class RMSNorm(nn.Module):
|
|||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
t = x.dtype
|
t = x.dtype
|
||||||
if t == mx.float16:
|
|
||||||
x = x.astype(mx.float32)
|
|
||||||
x = mx.clip(x, a_min=-1e9, a_max=1e9)
|
|
||||||
output = self._norm(x).astype(t)
|
output = self._norm(x).astype(t)
|
||||||
return self.weight * output
|
return self.weight * output
|
||||||
|
|
||||||
@ -406,7 +403,7 @@ if __name__ == "__main__":
|
|||||||
help="The model data type.",
|
help="The model data type.",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["float16", "bfloat16", "float32"],
|
choices=["float16", "bfloat16", "float32"],
|
||||||
default="float16",
|
default="float32",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||||
|
Loading…
Reference in New Issue
Block a user