Support snapshot_download for ModelScope (#1194)

* add MLX_USE_MODELSCOPE env

* update

* update snapshot_download

* update

* remove modelscope dependency and add import check

* update

* nits

* fix

---------

Co-authored-by: wangxingjun778 <jason@U-C7X6TX5G-2239.local>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Xingjun.Wang 2025-01-11 07:29:34 +08:00 committed by GitHub
parent 93c5cfd781
commit 514502da22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,6 +7,7 @@ import glob
import importlib
import json
import logging
import os
import shutil
import time
from dataclasses import dataclass
@ -16,7 +17,17 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx
import mlx.nn as nn
if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true":
try:
from modelscope import snapshot_download
except ImportError:
raise ImportError(
"Please run `pip install modelscope` to activate the ModelScope."
)
else:
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer
@ -154,11 +165,12 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
Path: The path to the model.
"""
model_path = Path(path_or_hf_repo)
if not model_path.exists():
try:
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.json",