mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
refactor: add force_download parameter to get_model_path function (#800)
This commit is contained in:
parent
3f337e0f0a
commit
47060a8130
@ -63,7 +63,7 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_model_path(path_or_hf_repo: str) -> Path:
|
||||
def get_model_path(path_or_hf_repo: str, force_download: bool = False) -> Path:
|
||||
model_path = Path(path_or_hf_repo)
|
||||
if not model_path.exists():
|
||||
model_path = Path(
|
||||
@ -74,6 +74,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
|
||||
"*.json",
|
||||
"*.txt",
|
||||
],
|
||||
force_download=force_download,
|
||||
)
|
||||
)
|
||||
return model_path
|
||||
@ -107,9 +108,15 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
default="float32",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--force-download",
|
||||
help="Force download the model from Hugging Face.",
|
||||
action="store_true",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
torch_path = get_model_path(args.hf_repo)
|
||||
torch_path = get_model_path(args.hf_repo, args.force_download)
|
||||
mlx_path = Path(args.mlx_path)
|
||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user