From 2954fc56dd9adf11aaa27d25f0455a7b3408c18c Mon Sep 17 00:00:00 2001 From: "xingjun.wang" Date: Mon, 6 Jan 2025 00:40:25 +0800 Subject: [PATCH] use ModelScope snapshot_download when MLX_USE_MODELSCOPE is set --- llms/mlx_lm/utils.py | 51 ++++++++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index e604b09c..670af8f6 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -16,11 +16,13 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, import mlx.core as mx import mlx.nn as nn -if os.getenv('MLX_USE_MODELSCOPE', 'False').lower() == 'true': - print(">> Using ModelScope") + +use_modelscope = os.getenv('MLX_USE_MODELSCOPE', 'False').lower() == 'true' +if use_modelscope: from modelscope import snapshot_download else: from huggingface_hub import snapshot_download + from mlx.utils import tree_flatten, tree_reduce from transformers import PreTrainedTokenizer @@ -158,26 +160,39 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path Path: The path to the model. 
""" model_path = Path(path_or_hf_repo) - print(f'>>model_path: {model_path}') - revision = revision or 'master' - print(f'>>revision: {revision}') if not model_path.exists(): try: - model_path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - revision=revision, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - "*.txt", - ], + if use_modelscope: + model_path = Path( + snapshot_download( + model_id=path_or_hf_repo, + revision=revision or 'master', + allow_patterns=[ + "*.json", + "*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + "*.txt", + ], + ) + ) + else: + model_path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + revision=revision, + allow_patterns=[ + "*.json", + "*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + "*.txt", + ], + ) ) - ) except: raise ModelNotFoundError( f"Model not found for path or HF repo: {path_or_hf_repo}.\n"