update snapshot_download

This commit is contained in:
xingjun.wang 2025-01-06 00:40:25 +08:00
parent 5c535a28c0
commit 2954fc56dd

View File

@ -16,11 +16,13 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx
import mlx.nn as nn
if os.getenv('MLX_USE_MODELSCOPE', 'False').lower() == 'true':
print(">> Using ModelScope")
use_modelscope = os.getenv('MLX_USE_MODELSCOPE', 'False').lower() == 'true'
if use_modelscope:
from modelscope import snapshot_download
else:
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer
@ -158,26 +160,39 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
Path: The path to the model.
"""
model_path = Path(path_or_hf_repo)
print(f'>>model_path: {model_path}')
revision = revision or 'master'
print(f'>>revision: {revision}')
if not model_path.exists():
try:
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
],
if use_modelscope:
model_path = Path(
snapshot_download(
model_id=path_or_hf_repo,
revision=revision or 'master',
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
],
)
)
else:
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
],
)
)
)
except:
raise ModelNotFoundError(
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"