Update t5 docs on variant support

This commit is contained in:
Juarez Bochi 2023-12-18 22:59:36 -05:00
parent 930cd4d950
commit 6f4e33eff5
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6
3 changed files with 4 additions and 27 deletions

View File

@ -25,6 +25,8 @@ The `<model>` can be any of the following:
| t5-3b | 3 billion |
| t5-11b | 11 billion |
It also supports t5 variants, such as `google/flan-t5-small`, `google/flan-t5-base`, etc.
## Generate
Generate text with:

View File

@ -62,19 +62,6 @@ if __name__ == "__main__":
"--model",
type=str,
help="Name of the T5 model.",
choices=[
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
"google/flan-t5-small",
"google/flan-t5-base",
"google/flan-t5-large",
"google/flan-t5-xl",
"google/flan-t5-xxl",
"google/flan-t5-ul2",
],
default="t5-small",
)
args = parser.parse_args()

View File

@ -384,7 +384,8 @@ def load_model(model_name: str, dtype: str = "float16"):
config = T5Config.from_pretrained(args.model)
dtype = getattr(mx, dtype)
model = T5(config)
weights = mx.load(f"{model_name}.npz")
file_name = model_name.replace("/", "-")
weights = mx.load(f"{file_name}.npz")
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model.update(weights)
@ -398,19 +399,6 @@ if __name__ == "__main__":
"--model",
type=str,
help="Name of the T5 model.",
choices=[
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
"google/flan-t5-small",
"google/flan-t5-base",
"google/flan-t5-large",
"google/flan-t5-xl",
"google/flan-t5-xxl",
"google/flan-t5-ul2",
],
default="t5-small",
)
parser.add_argument(