mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Support snapshot_download for ModelScope (#1194)
* add MLX_USE_MODELSCOPE env * update * update snapshot_download * update * remove modelscope dependency and add import check * update * nits * fix --------- Co-authored-by: wangxingjun778 <jason@U-C7X6TX5G-2239.local> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
93c5cfd781
commit
514502da22
@ -7,6 +7,7 @@ import glob
|
|||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -16,7 +17,17 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true":
|
||||||
|
try:
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Please run `pip install modelscope` to activate the ModelScope."
|
||||||
|
)
|
||||||
|
else:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from mlx.utils import tree_flatten, tree_reduce
|
from mlx.utils import tree_flatten, tree_reduce
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
@ -154,11 +165,12 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
|||||||
Path: The path to the model.
|
Path: The path to the model.
|
||||||
"""
|
"""
|
||||||
model_path = Path(path_or_hf_repo)
|
model_path = Path(path_or_hf_repo)
|
||||||
|
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
try:
|
try:
|
||||||
model_path = Path(
|
model_path = Path(
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=path_or_hf_repo,
|
path_or_hf_repo,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
allow_patterns=[
|
allow_patterns=[
|
||||||
"*.json",
|
"*.json",
|
||||||
|
Loading…
Reference in New Issue
Block a user