Support snapshot_download for ModelScope (#1194)

* add MLX_USE_MODELSCOPE env

* update

* update snapshot_download

* update

* remove modelscope dependency and add import check

* update

* nits

* fix

---------

Co-authored-by: wangxingjun778 <jason@U-C7X6TX5G-2239.local>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Xingjun.Wang 2025-01-11 07:29:34 +08:00 committed by GitHub
parent 93c5cfd781
commit 514502da22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,6 +7,7 @@ import glob
import importlib import importlib
import json import json
import logging import logging
import os
import shutil import shutil
import time import time
from dataclasses import dataclass from dataclasses import dataclass
@ -16,7 +17,17 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from huggingface_hub import snapshot_download
if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true":
try:
from modelscope import snapshot_download
except ImportError:
raise ImportError(
"Please run `pip install modelscope` to activate the ModelScope."
)
else:
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_reduce from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@ -154,11 +165,12 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
Path: The path to the model. Path: The path to the model.
""" """
model_path = Path(path_or_hf_repo) model_path = Path(path_or_hf_repo)
if not model_path.exists(): if not model_path.exists():
try: try:
model_path = Path( model_path = Path(
snapshot_download( snapshot_download(
repo_id=path_or_hf_repo, path_or_hf_repo,
revision=revision, revision=revision,
allow_patterns=[ allow_patterns=[
"*.json", "*.json",