mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
update snapshot_download
This commit is contained in:
parent
5c535a28c0
commit
2954fc56dd
@ -16,11 +16,13 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
if os.getenv('MLX_USE_MODELSCOPE', 'False').lower() == 'true':
|
|
||||||
print(">> Using ModelScope")
|
use_modelscope = os.getenv('MLX_USE_MODELSCOPE', 'False').lower() == 'true'
|
||||||
|
if use_modelscope:
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
else:
|
else:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from mlx.utils import tree_flatten, tree_reduce
|
from mlx.utils import tree_flatten, tree_reduce
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
@ -158,12 +160,25 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
|||||||
Path: The path to the model.
|
Path: The path to the model.
|
||||||
"""
|
"""
|
||||||
model_path = Path(path_or_hf_repo)
|
model_path = Path(path_or_hf_repo)
|
||||||
print(f'>>model_path: {model_path}')
|
|
||||||
revision = revision or 'master'
|
|
||||||
print(f'>>revision: {revision}')
|
|
||||||
|
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
try:
|
try:
|
||||||
|
if use_modelscope:
|
||||||
|
model_path = Path(
|
||||||
|
snapshot_download(
|
||||||
|
model_id=path_or_hf_repo,
|
||||||
|
revision=revision or 'master',
|
||||||
|
allow_patterns=[
|
||||||
|
"*.json",
|
||||||
|
"*.safetensors",
|
||||||
|
"*.py",
|
||||||
|
"tokenizer.model",
|
||||||
|
"*.tiktoken",
|
||||||
|
"*.txt",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
model_path = Path(
|
model_path = Path(
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=path_or_hf_repo,
|
repo_id=path_or_hf_repo,
|
||||||
|
Loading…
Reference in New Issue
Block a user