From 47060a8130f5a9489a046694e30633df05f28ab8 Mon Sep 17 00:00:00 2001 From: "M. Ali Bayram" Date: Tue, 23 Jul 2024 23:10:20 +0300 Subject: [PATCH] refactor: add force_download parameter to get_model_path function (#800) --- clip/convert.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/clip/convert.py b/clip/convert.py index a646f93f..29bac22e 100644 --- a/clip/convert.py +++ b/clip/convert.py @@ -63,7 +63,7 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None: ) -def get_model_path(path_or_hf_repo: str) -> Path: +def get_model_path(path_or_hf_repo: str, force_download: bool = False) -> Path: model_path = Path(path_or_hf_repo) if not model_path.exists(): model_path = Path( @@ -74,6 +74,7 @@ def get_model_path(path_or_hf_repo: str) -> Path: "*.json", "*.txt", ], + force_download=force_download, ) ) return model_path @@ -107,9 +108,15 @@ if __name__ == "__main__": type=str, default="float32", ) + parser.add_argument( + "-f", + "--force-download", + help="Force download the model from Hugging Face.", + action="store_true", + ) args = parser.parse_args() - torch_path = get_model_path(args.hf_repo) + torch_path = get_model_path(args.hf_repo, args.force_download) mlx_path = Path(args.mlx_path) mlx_path.mkdir(parents=True, exist_ok=True)