T5: Change default to bfloat16

This commit is contained in:
Juarez Bochi 2023-12-19 14:53:07 -05:00
parent 10a7b99e83
commit 6d5d8c6bca
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6
2 changed files with 14 additions and 5 deletions

View File

@@ -44,13 +44,15 @@ def replace_key(key: str) -> str:
return key
def convert(model_name, dtype="float16"):
    """Convert a Hugging Face T5 checkpoint into a NumPy ``.npz`` archive.

    Every entry of the model's ``state_dict`` is renamed with ``replace_key``,
    cast to ``dtype`` and written to ``<model_name>.npz`` (with ``/`` in the
    model name replaced by ``-``) in the current working directory.

    Args:
        model_name: Model identifier handed to
            ``T5ForConditionalGeneration.from_pretrained``.
        dtype: NumPy dtype specification (e.g. ``"float16"`` or
            ``"float32"``) the weights are cast to before saving. Defaults to
            ``"float16"``, the value that was hard-coded before this
            parameter was introduced, so one-argument callers keep working.

    Raises:
        TypeError: If ``dtype`` is not a specification ``np.dtype`` accepts.
    """
    # np.dtype() rejects unknown names up front; getattr(np, name) would
    # return *any* numpy attribute and only fail later inside astype().
    target_dtype = np.dtype(dtype)

    model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
    weights = {
        replace_key(name): tensor.numpy().astype(target_dtype)
        for name, tensor in model.state_dict().items()
    }

    # Hub names such as "google/t5-v1_1-small" contain a slash, which would
    # otherwise be interpreted as a directory separator.
    file_name = model_name.replace("/", "-")
    print(f"Saving weights to {file_name}.npz")
    np.savez(f"{file_name}.npz", **weights)
@@ -64,5 +66,12 @@ if __name__ == "__main__":
help="Name of the T5 model.",
default="t5-small",
)
parser.add_argument(
"--dtype",
help="The model data type.",
type=str,
choices=["float16", "float32"],
default="float16",
)
args = parser.parse_args()
convert(args.model)
convert(args.model, args.dtype)

View File

@@ -337,7 +337,7 @@ class Tokenizer:
self._tokenizer = T5Tokenizer.from_pretrained(
args.model,
legacy=False,
model_max_length=config.n_positions,
model_max_length=getattr(config, 'n_positions', 512)
)
@property
@@ -430,7 +430,7 @@ if __name__ == "__main__":
help="The model data type.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="float32",
default="bfloat16",
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")