mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Add argument to generate float16 npz
This commit is contained in:
parent
09e851499a
commit
b2a3782a96
@ -46,9 +46,11 @@ def replace_key(key: str) -> str:
|
||||
return key
|
||||
|
||||
|
||||
def convert(model_name):
|
||||
def convert(model_name, half_precision=False):
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
||||
weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
|
||||
if half_precision:
|
||||
weights = {k: v.astype(np.float16) for k, v in weights.items()}
|
||||
np.savez("weights.npz", **weights)
|
||||
|
||||
|
||||
@ -63,5 +65,10 @@ if __name__ == "__main__":
|
||||
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
||||
default="t5-small",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--half-precision",
|
||||
action="store_true",
|
||||
help="Convert weights to half precision (float16).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert(args.model_name)
|
||||
convert(args.model_name, args.half_precision)
|
||||
|
4
t5/t5.py
4
t5/t5.py
@ -353,7 +353,7 @@ def load_model(model_config):
|
||||
print("Loading shape: ", weights_to_load_dict[key].shape)
|
||||
model.update(tree_unflatten(weights_to_load))
|
||||
mx.eval(model.parameters())
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small", trust_remote_code=True)
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@ -368,7 +368,7 @@ if __name__ == "__main__":
|
||||
"--encode-only",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to decode or not",
|
||||
help="Whether to decode or not. If true, will output last layer of encoder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_tokens",
|
||||
|
Loading…
Reference in New Issue
Block a user