mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Add argument to generate float16 npz
This commit is contained in:
parent
09e851499a
commit
b2a3782a96
@ -46,9 +46,11 @@ def replace_key(key: str) -> str:
|
|||||||
return key
|
return key
|
||||||
|
|
||||||
|
|
||||||
def convert(model_name):
|
def convert(model_name, half_precision=False):
|
||||||
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
||||||
weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
|
weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
|
||||||
|
if half_precision:
|
||||||
|
weights = {k: v.astype(np.float16) for k, v in weights.items()}
|
||||||
np.savez("weights.npz", **weights)
|
np.savez("weights.npz", **weights)
|
||||||
|
|
||||||
|
|
||||||
@ -63,5 +65,10 @@ if __name__ == "__main__":
|
|||||||
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
||||||
default="t5-small",
|
default="t5-small",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--half-precision",
|
||||||
|
action="store_true",
|
||||||
|
help="Convert weights to half precision (float16).",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert(args.model_name)
|
convert(args.model_name, args.half_precision)
|
||||||
|
4
t5/t5.py
4
t5/t5.py
@ -353,7 +353,7 @@ def load_model(model_config):
|
|||||||
print("Loading shape: ", weights_to_load_dict[key].shape)
|
print("Loading shape: ", weights_to_load_dict[key].shape)
|
||||||
model.update(tree_unflatten(weights_to_load))
|
model.update(tree_unflatten(weights_to_load))
|
||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
tokenizer = T5Tokenizer.from_pretrained("t5-small", trust_remote_code=True)
|
tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
@ -368,7 +368,7 @@ if __name__ == "__main__":
|
|||||||
"--encode-only",
|
"--encode-only",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=False,
|
default=False,
|
||||||
help="Whether to decode or not",
|
help="Whether to decode or not. If true, will output last layer of encoder.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_tokens",
|
"--max_tokens",
|
||||||
|
Loading…
Reference in New Issue
Block a user