This commit is contained in:
Awni Hannun 2025-01-10 15:23:27 -08:00
parent c781df114a
commit d20181e692

View File

@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
import os
import contextlib
import copy
import glob
import importlib
import json
import logging
import os
import shutil
import time
from dataclasses import dataclass
@ -17,8 +17,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx
import mlx.nn as nn
use_modelscope = os.getenv('MLX_USE_MODELSCOPE', 'False').lower() == 'true'
if use_modelscope:
if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true":
try:
from modelscope import snapshot_download
except ImportError:
@ -168,37 +167,20 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
if not model_path.exists():
try:
if use_modelscope:
print(f"Downloading model from Modelscope: {path_or_hf_repo}")
model_path = Path(
snapshot_download(
model_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
],
)
)
else:
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
],
)
model_path = Path(
snapshot_download(
model_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
],
)
)
except:
raise ModelNotFoundError(
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"