Add argument to generate float16 npz

This commit is contained in:
Juarez Bochi 2023-12-18 08:21:20 -05:00
parent 09e851499a
commit b2a3782a96
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6
2 changed files with 11 additions and 4 deletions

View File

@ -46,9 +46,11 @@ def replace_key(key: str) -> str:
return key
def convert(model_name, half_precision=False):
    """Export a pretrained T5 checkpoint to ``weights.npz``.

    The model's state dict is loaded, every key is passed through
    ``replace_key``, each tensor is converted to a NumPy array, and the
    result is saved with ``np.savez`` in the current working directory.

    Args:
        model_name: Checkpoint name handed to
            ``T5ForConditionalGeneration.from_pretrained``.
        half_precision: When true, floating-point arrays are cast to
            ``np.float16`` before being saved.
    """
    model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
    weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
    if half_precision:
        # Only floating-point arrays are downcast; any integer or boolean
        # entries keep their dtype so their values are not altered.
        weights = {
            k: v.astype(np.float16) if np.issubdtype(v.dtype, np.floating) else v
            for k, v in weights.items()
        }
    np.savez("weights.npz", **weights)
@ -63,5 +65,10 @@ if __name__ == "__main__":
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
default="t5-small",
)
parser.add_argument(
"--half-precision",
action="store_true",
help="Convert weights to half precision (float16).",
)
args = parser.parse_args()
convert(args.model_name)
convert(args.model_name, args.half_precision)

View File

@ -353,7 +353,7 @@ def load_model(model_config):
print("Loading shape: ", weights_to_load_dict[key].shape)
model.update(tree_unflatten(weights_to_load))
mx.eval(model.parameters())
tokenizer = T5Tokenizer.from_pretrained("t5-small", trust_remote_code=True)
tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
return model, tokenizer
@ -368,7 +368,7 @@ if __name__ == "__main__":
"--encode-only",
action="store_true",
default=False,
help="Whether to decode or not",
help="Whether to decode or not. If true, will output last layer of encoder.",
)
parser.add_argument(
"--max_tokens",