diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index e604b09c..670af8f6 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -16,11 +16,13 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 import mlx.core as mx
 import mlx.nn as nn
 
-if os.getenv('MLX_USE_MODELSCOPE', 'False').lower() == 'true':
-    print(">> Using ModelScope")
+
+use_modelscope = os.getenv('MLX_USE_MODELSCOPE', 'False').lower() == 'true'
+if use_modelscope:
     from modelscope import snapshot_download
 else:
     from huggingface_hub import snapshot_download
+
 from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
@@ -158,26 +160,39 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
         Path: The path to the model.
     """
     model_path = Path(path_or_hf_repo)
-    print(f'>>model_path: {model_path}')
-    revision = revision or 'master'
-    print(f'>>revision: {revision}')
 
     if not model_path.exists():
         try:
-            model_path = Path(
-                snapshot_download(
-                    repo_id=path_or_hf_repo,
-                    revision=revision,
-                    allow_patterns=[
-                        "*.json",
-                        "*.safetensors",
-                        "*.py",
-                        "tokenizer.model",
-                        "*.tiktoken",
-                        "*.txt",
-                    ],
+            if use_modelscope:
+                model_path = Path(
+                    snapshot_download(
+                        model_id=path_or_hf_repo,
+                        revision=revision or 'master',
+                        allow_patterns=[
+                            "*.json",
+                            "*.safetensors",
+                            "*.py",
+                            "tokenizer.model",
+                            "*.tiktoken",
+                            "*.txt",
+                        ],
+                    )
+                )
+            else:
+                model_path = Path(
+                    snapshot_download(
+                        repo_id=path_or_hf_repo,
+                        revision=revision,
+                        allow_patterns=[
+                            "*.json",
+                            "*.safetensors",
+                            "*.py",
+                            "tokenizer.model",
+                            "*.tiktoken",
+                            "*.txt",
+                        ],
+                    )
                 )
-            )
         except:
             raise ModelNotFoundError(
                 f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
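
Review note: the two branches added in the second hunk differ only in the repo keyword (`model_id` for ModelScope vs. `repo_id` for Hugging Face) and in ModelScope's `'master'` default revision. Below is a minimal deduplication sketch, not part of this patch: the helper name `_download_snapshot` and the `ALLOW_PATTERNS` constant are hypothetical, and it assumes both backends accept the keyword arguments exactly as used in the diff above.

```python
import os
from pathlib import Path

# Same import-time backend switch as in the patch.
use_modelscope = os.getenv("MLX_USE_MODELSCOPE", "False").lower() == "true"
if use_modelscope:
    from modelscope import snapshot_download
else:
    from huggingface_hub import snapshot_download

# Shared by both backends; copied verbatim from the patch.
ALLOW_PATTERNS = [
    "*.json",
    "*.safetensors",
    "*.py",
    "tokenizer.model",
    "*.tiktoken",
    "*.txt",
]

def _download_snapshot(path_or_hf_repo: str, revision=None) -> Path:
    # Hypothetical helper: the only per-backend differences are the
    # keyword naming the repo and ModelScope's 'master' default revision,
    # both taken from the diff above.
    if use_modelscope:
        kwargs = {"model_id": path_or_hf_repo, "revision": revision or "master"}
    else:
        kwargs = {"repo_id": path_or_hf_repo, "revision": revision}
    return Path(snapshot_download(allow_patterns=ALLOW_PATTERNS, **kwargs))
```

Separately, note that `use_modelscope` is evaluated once at module load, so `MLX_USE_MODELSCOPE` must be set in the environment before `mlx_lm.utils` is first imported; setting it afterwards has no effect on which backend is used.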