diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 2fc0446b..b9037295 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -7,6 +7,7 @@ import glob import importlib import json import logging +import os import shutil import time from dataclasses import dataclass @@ -16,7 +17,17 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, import mlx.core as mx import mlx.nn as nn -from huggingface_hub import snapshot_download + +if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true": + try: + from modelscope import snapshot_download + except ImportError: + raise ImportError( + "Please run `pip install modelscope` to activate the ModelScope." + ) +else: + from huggingface_hub import snapshot_download + from mlx.utils import tree_flatten, tree_reduce from transformers import PreTrainedTokenizer @@ -154,11 +165,12 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path Path: The path to the model. """ model_path = Path(path_or_hf_repo) + if not model_path.exists(): try: model_path = Path( snapshot_download( - repo_id=path_or_hf_repo, + path_or_hf_repo, revision=revision, allow_patterns=[ "*.json",