refactor: add force_download parameter to get_model_path function (#800)

This commit is contained in:
M. Ali Bayram 2024-07-23 23:10:20 +03:00 committed by GitHub
parent 3f337e0f0a
commit 47060a8130
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -63,7 +63,7 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
)
def get_model_path(path_or_hf_repo: str) -> Path:
def get_model_path(path_or_hf_repo: str, force_download: bool = False) -> Path:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
@ -74,6 +74,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
"*.json",
"*.txt",
],
force_download=force_download,
)
)
return model_path
@ -107,9 +108,15 @@ if __name__ == "__main__":
type=str,
default="float32",
)
parser.add_argument(
"-f",
"--force-download",
help="Force download the model from Hugging Face.",
action="store_true",
)
args = parser.parse_args()
torch_path = get_model_path(args.hf_repo)
torch_path = get_model_path(args.hf_repo, args.force_download)
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)