mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 06:54:39 +08:00
nits
This commit is contained in:
parent
c781df114a
commit
d20181e692
@ -1,12 +1,12 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import os
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -17,8 +17,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
use_modelscope = os.getenv('MLX_USE_MODELSCOPE', 'False').lower() == 'true'
|
if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true":
|
||||||
if use_modelscope:
|
|
||||||
try:
|
try:
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -168,37 +167,20 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
|||||||
|
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
try:
|
try:
|
||||||
if use_modelscope:
|
model_path = Path(
|
||||||
print(f"Downloading model from Modelscope: {path_or_hf_repo}")
|
snapshot_download(
|
||||||
model_path = Path(
|
model_id=path_or_hf_repo,
|
||||||
snapshot_download(
|
revision=revision,
|
||||||
model_id=path_or_hf_repo,
|
allow_patterns=[
|
||||||
revision=revision,
|
"*.json",
|
||||||
allow_patterns=[
|
"*.safetensors",
|
||||||
"*.json",
|
"*.py",
|
||||||
"*.safetensors",
|
"tokenizer.model",
|
||||||
"*.py",
|
"*.tiktoken",
|
||||||
"tokenizer.model",
|
"*.txt",
|
||||||
"*.tiktoken",
|
],
|
||||||
"*.txt",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model_path = Path(
|
|
||||||
snapshot_download(
|
|
||||||
repo_id=path_or_hf_repo,
|
|
||||||
revision=revision,
|
|
||||||
allow_patterns=[
|
|
||||||
"*.json",
|
|
||||||
"*.safetensors",
|
|
||||||
"*.py",
|
|
||||||
"tokenizer.model",
|
|
||||||
"*.tiktoken",
|
|
||||||
"*.txt",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
raise ModelNotFoundError(
|
raise ModelNotFoundError(
|
||||||
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
|
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
|
||||||
|
Loading…
Reference in New Issue
Block a user