mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 09:48:54 +08:00
T5: Change default dtype to bfloat16 (#147)
* T5: Change default to bfloat16
* Add myself to contributors
* t5: Change convert.py default to float32
This commit is contained in:
@@ -44,13 +44,15 @@ def replace_key(key: str) -> str:
|
||||
return key
|
||||
|
||||
|
||||
def convert(model_name):
|
||||
def convert(model_name, dtype):
|
||||
dtype = getattr(np, dtype)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
||||
weights = {
|
||||
replace_key(k): v.numpy().astype(np.float16)
|
||||
replace_key(k): v.numpy().astype(dtype)
|
||||
for k, v in model.state_dict().items()
|
||||
}
|
||||
file_name = model_name.replace("/", "-")
|
||||
print(f"Saving weights to {file_name}.npz")
|
||||
np.savez(f"{file_name}.npz", **weights)
|
||||
|
||||
|
||||
@@ -64,5 +66,12 @@ if __name__ == "__main__":
|
||||
help="Name of the T5 model.",
|
||||
default="t5-small",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="The model data type.",
|
||||
type=str,
|
||||
choices=["float16", "float32"],
|
||||
default="float32",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert(args.model)
|
||||
convert(args.model, args.dtype)
|
||||
|
||||
Reference in New Issue
Block a user