From 6f4e33eff543503b4943a902f7d664890256ea5a Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Mon, 18 Dec 2023 22:59:36 -0500 Subject: [PATCH] Update t5 docs on variant support --- t5/README.md | 2 ++ t5/convert.py | 13 ------------- t5/t5.py | 16 ++-------------- 3 files changed, 4 insertions(+), 27 deletions(-) diff --git a/t5/README.md b/t5/README.md index 6c1c254d..289c7010 100644 --- a/t5/README.md +++ b/t5/README.md @@ -25,6 +25,8 @@ The `` can be any of the following: | t5-3b | 3 billion | | t5-11b | 11 billion | +It also supports t5 variants, such as `google/flan-t5-small`, `google/flan-t5-base`, etc. + ## Generate Generate text with: diff --git a/t5/convert.py b/t5/convert.py index 8e1d327d..71b009da 100644 --- a/t5/convert.py +++ b/t5/convert.py @@ -62,19 +62,6 @@ if __name__ == "__main__": "--model", type=str, help="Name of the T5 model.", - choices=[ - "t5-small", - "t5-base", - "t5-large", - "t5-3b", - "t5-11b", - "google/flan-t5-small", - "google/flan-t5-base", - "google/flan-t5-large", - "google/flan-t5-xl", - "google/flan-t5-xxl", - "google/flan-t5-ul2", - ], default="t5-small", ) args = parser.parse_args() diff --git a/t5/t5.py b/t5/t5.py index d4b1c1db..6dc5835d 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -384,7 +384,8 @@ def load_model(model_name: str, dtype: str = "float16"): config = T5Config.from_pretrained(args.model) dtype = getattr(mx, dtype) model = T5(config) - weights = mx.load(f"{model_name}.npz") + file_name = model_name.replace("/", "-") + weights = mx.load(f"{file_name}.npz") weights = tree_unflatten(list(weights.items())) weights = tree_map(lambda p: p.astype(dtype), weights) model.update(weights) @@ -398,19 +399,6 @@ if __name__ == "__main__": "--model", type=str, help="Name of the T5 model.", - choices=[ - "t5-small", - "t5-base", - "t5-large", - "t5-3b", - "t5-11b", - "google/flan-t5-small", - "google/flan-t5-base", - "google/flan-t5-large", - "google/flan-t5-xl", - "google/flan-t5-xxl", - "google/flan-t5-ul2", - ], default="t5-small", ) parser.add_argument(