diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 7b7b030d..3c368003 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -6,3 +6,5 @@ with a short description of your contribution(s) below. For example:
 - Jane Smith: Added the `foo` example.
 
 MLX Examples was developed with contributions from the following individuals:
+
+- Juarez Bochi: Added support for T5 models.
diff --git a/t5/convert.py b/t5/convert.py
index 71b009da..089d262d 100644
--- a/t5/convert.py
+++ b/t5/convert.py
@@ -44,13 +44,15 @@ def replace_key(key: str) -> str:
     return key
 
 
-def convert(model_name):
+def convert(model_name, dtype):
+    dtype = getattr(np, dtype)
     model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
     weights = {
-        replace_key(k): v.numpy().astype(np.float16)
+        replace_key(k): v.numpy().astype(dtype)
         for k, v in model.state_dict().items()
     }
     file_name = model_name.replace("/", "-")
+    print(f"Saving weights to {file_name}.npz")
     np.savez(f"{file_name}.npz", **weights)
 
 
@@ -64,5 +66,12 @@ if __name__ == "__main__":
         help="Name of the T5 model.",
         default="t5-small",
     )
+    parser.add_argument(
+        "--dtype",
+        help="The model data type.",
+        type=str,
+        choices=["float16", "float32"],
+        default="float32",
+    )
     args = parser.parse_args()
-    convert(args.model)
+    convert(args.model, args.dtype)
diff --git a/t5/t5.py b/t5/t5.py
index 6dc5835d..f80c3cb3 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -337,7 +337,7 @@ class Tokenizer:
         self._tokenizer = T5Tokenizer.from_pretrained(
             args.model,
             legacy=False,
-            model_max_length=config.n_positions,
+            model_max_length=getattr(config, 'n_positions', 512)
         )
 
     @property
@@ -430,7 +430,7 @@ if __name__ == "__main__":
         help="The model data type.",
         type=str,
         choices=["float16", "bfloat16", "float32"],
-        default="float32",
+        default="bfloat16",
     )
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
 