mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Update t5 docs on variant support
This commit is contained in:
parent
930cd4d950
commit
6f4e33eff5
@ -25,6 +25,8 @@ The `<model>` can be any of the following:
|
|||||||
| t5-3b | 3 billion |
|
| t5-3b | 3 billion |
|
||||||
| t5-11b | 11 billion |
|
| t5-11b | 11 billion |
|
||||||
|
|
||||||
|
It also supports T5 variants such as `google/flan-t5-small` and `google/flan-t5-base`.
|
||||||
|
|
||||||
## Generate
|
## Generate
|
||||||
|
|
||||||
Generate text with:
|
Generate text with:
|
||||||
|
@ -62,19 +62,6 @@ if __name__ == "__main__":
|
|||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
help="Name of the T5 model.",
|
help="Name of the T5 model.",
|
||||||
choices=[
|
|
||||||
"t5-small",
|
|
||||||
"t5-base",
|
|
||||||
"t5-large",
|
|
||||||
"t5-3b",
|
|
||||||
"t5-11b",
|
|
||||||
"google/flan-t5-small",
|
|
||||||
"google/flan-t5-base",
|
|
||||||
"google/flan-t5-large",
|
|
||||||
"google/flan-t5-xl",
|
|
||||||
"google/flan-t5-xxl",
|
|
||||||
"google/flan-t5-ul2",
|
|
||||||
],
|
|
||||||
default="t5-small",
|
default="t5-small",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
16
t5/t5.py
16
t5/t5.py
@ -384,7 +384,8 @@ def load_model(model_name: str, dtype: str = "float16"):
|
|||||||
config = T5Config.from_pretrained(args.model)
|
config = T5Config.from_pretrained(args.model)
|
||||||
dtype = getattr(mx, dtype)
|
dtype = getattr(mx, dtype)
|
||||||
model = T5(config)
|
model = T5(config)
|
||||||
weights = mx.load(f"{model_name}.npz")
|
file_name = model_name.replace("/", "-")
|
||||||
|
weights = mx.load(f"{file_name}.npz")
|
||||||
weights = tree_unflatten(list(weights.items()))
|
weights = tree_unflatten(list(weights.items()))
|
||||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||||
model.update(weights)
|
model.update(weights)
|
||||||
@ -398,19 +399,6 @@ if __name__ == "__main__":
|
|||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
help="Name of the T5 model.",
|
help="Name of the T5 model.",
|
||||||
choices=[
|
|
||||||
"t5-small",
|
|
||||||
"t5-base",
|
|
||||||
"t5-large",
|
|
||||||
"t5-3b",
|
|
||||||
"t5-11b",
|
|
||||||
"google/flan-t5-small",
|
|
||||||
"google/flan-t5-base",
|
|
||||||
"google/flan-t5-large",
|
|
||||||
"google/flan-t5-xl",
|
|
||||||
"google/flan-t5-xxl",
|
|
||||||
"google/flan-t5-ul2",
|
|
||||||
],
|
|
||||||
default="t5-small",
|
default="t5-small",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
Loading…
Reference in New Issue
Block a user