diff --git a/t5/convert.py b/t5/convert.py index 35374724..c9377b5e 100644 --- a/t5/convert.py +++ b/t5/convert.py @@ -46,9 +46,11 @@ def replace_key(key: str) -> str: return key -def convert(model_name): +def convert(model_name, half_precision=False): model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto") weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()} + if half_precision: + weights = {k: v.astype(np.float16) for k, v in weights.items()} np.savez("weights.npz", **weights) @@ -63,5 +65,10 @@ if __name__ == "__main__": choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"], default="t5-small", ) + parser.add_argument( + "--half-precision", + action="store_true", + help="Convert weights to half precision (float16).", + ) args = parser.parse_args() - convert(args.model_name) + convert(args.model_name, args.half_precision) diff --git a/t5/t5.py b/t5/t5.py index 8fd2936a..6f53694d 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -353,7 +353,7 @@ def load_model(model_config): print("Loading shape: ", weights_to_load_dict[key].shape) model.update(tree_unflatten(weights_to_load)) mx.eval(model.parameters()) - tokenizer = T5Tokenizer.from_pretrained("t5-small", trust_remote_code=True) + tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False) return model, tokenizer @@ -368,7 +368,7 @@ if __name__ == "__main__": "--encode-only", action="store_true", default=False, - help="Whether to decode or not", + help="Whether to decode or not. If true, will output the last layer of the encoder.", ) parser.add_argument( "--max_tokens",