This commit is contained in:
Awni Hannun 2025-01-10 15:23:27 -08:00
parent c781df114a
commit d20181e692

View File

@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import os
import contextlib import contextlib
import copy import copy
import glob import glob
import importlib import importlib
import json import json
import logging import logging
import os
import shutil import shutil
import time import time
from dataclasses import dataclass from dataclasses import dataclass
@ -17,8 +17,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
use_modelscope = os.getenv('MLX_USE_MODELSCOPE', 'False').lower() == 'true' if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true":
if use_modelscope:
try: try:
from modelscope import snapshot_download from modelscope import snapshot_download
except ImportError: except ImportError:
@ -168,37 +167,20 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
if not model_path.exists(): if not model_path.exists():
try: try:
if use_modelscope: model_path = Path(
print(f"Downloading model from Modelscope: {path_or_hf_repo}") snapshot_download(
model_path = Path( model_id=path_or_hf_repo,
snapshot_download( revision=revision,
model_id=path_or_hf_repo, allow_patterns=[
revision=revision, "*.json",
allow_patterns=[ "*.safetensors",
"*.json", "*.py",
"*.safetensors", "tokenizer.model",
"*.py", "*.tiktoken",
"tokenizer.model", "*.txt",
"*.tiktoken", ],
"*.txt",
],
)
)
else:
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
],
)
) )
)
except: except:
raise ModelNotFoundError( raise ModelNotFoundError(
f"Model not found for path or HF repo: {path_or_hf_repo}.\n" f"Model not found for path or HF repo: {path_or_hf_repo}.\n"