mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Update t5 docs on variant support
This commit is contained in:
parent
930cd4d950
commit
6f4e33eff5
@ -25,6 +25,8 @@ The `<model>` can be any of the following:
|
||||
| t5-3b | 3 billion |
|
||||
| t5-11b | 11 billion |
|
||||
|
||||
It also supports t5 variants, such as `google/flan-t5-small`, `google/flan-t5-base`, etc.
|
||||
|
||||
## Generate
|
||||
|
||||
Generate text with:
|
||||
|
@ -62,19 +62,6 @@ if __name__ == "__main__":
|
||||
"--model",
|
||||
type=str,
|
||||
help="Name of the T5 model.",
|
||||
choices=[
|
||||
"t5-small",
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
"t5-3b",
|
||||
"t5-11b",
|
||||
"google/flan-t5-small",
|
||||
"google/flan-t5-base",
|
||||
"google/flan-t5-large",
|
||||
"google/flan-t5-xl",
|
||||
"google/flan-t5-xxl",
|
||||
"google/flan-t5-ul2",
|
||||
],
|
||||
default="t5-small",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
16
t5/t5.py
16
t5/t5.py
@ -384,7 +384,8 @@ def load_model(model_name: str, dtype: str = "float16"):
|
||||
config = T5Config.from_pretrained(args.model)
|
||||
dtype = getattr(mx, dtype)
|
||||
model = T5(config)
|
||||
weights = mx.load(f"{model_name}.npz")
|
||||
file_name = model_name.replace("/", "-")
|
||||
weights = mx.load(f"{file_name}.npz")
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model.update(weights)
|
||||
@ -398,19 +399,6 @@ if __name__ == "__main__":
|
||||
"--model",
|
||||
type=str,
|
||||
help="Name of the T5 model.",
|
||||
choices=[
|
||||
"t5-small",
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
"t5-3b",
|
||||
"t5-11b",
|
||||
"google/flan-t5-small",
|
||||
"google/flan-t5-base",
|
||||
"google/flan-t5-large",
|
||||
"google/flan-t5-xl",
|
||||
"google/flan-t5-xxl",
|
||||
"google/flan-t5-ul2",
|
||||
],
|
||||
default="t5-small",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
Loading…
Reference in New Issue
Block a user