mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
fix: Added dedicated error handling to load and get_model_path (#775)
* fix: Added dedicated error handling to load and get_model_path Added proper error handling to load and get_model_path by adding a dedicated exception class, because when the local path is not right, it still throws the huggingface RepositoryNotFoundError * fix: Changed error message and resolved lack of import * fix: Removed redundant try-catch block * nits in message * nits in message --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
e92de216fd
commit
199df9e110
@ -14,6 +14,7 @@ from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||||
|
|
||||||
@ -33,6 +34,12 @@ MODEL_REMAPPING = {
|
|||||||
MAX_FILE_SIZE_GB = 5
|
MAX_FILE_SIZE_GB = 5
|
||||||
|
|
||||||
|
|
||||||
|
class ModelNotFoundError(Exception):
|
||||||
|
def __init__(self, message):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
def _get_classes(config: dict):
|
def _get_classes(config: dict):
|
||||||
"""
|
"""
|
||||||
Retrieve the model and model args classes based on the configuration.
|
Retrieve the model and model args classes based on the configuration.
|
||||||
@ -69,20 +76,29 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
|||||||
"""
|
"""
|
||||||
model_path = Path(path_or_hf_repo)
|
model_path = Path(path_or_hf_repo)
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
model_path = Path(
|
try:
|
||||||
snapshot_download(
|
model_path = Path(
|
||||||
repo_id=path_or_hf_repo,
|
snapshot_download(
|
||||||
revision=revision,
|
repo_id=path_or_hf_repo,
|
||||||
allow_patterns=[
|
revision=revision,
|
||||||
"*.json",
|
allow_patterns=[
|
||||||
"*.safetensors",
|
"*.json",
|
||||||
"*.py",
|
"*.safetensors",
|
||||||
"tokenizer.model",
|
"*.py",
|
||||||
"*.tiktoken",
|
"tokenizer.model",
|
||||||
"*.txt",
|
"*.tiktoken",
|
||||||
],
|
"*.txt",
|
||||||
|
],
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
except RepositoryNotFoundError:
|
||||||
|
raise ModelNotFoundError(
|
||||||
|
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
|
||||||
|
"Please make sure you specified the local path or Hugging Face"
|
||||||
|
" repo id correctly.\nIf you are trying to access a private or"
|
||||||
|
" gated Hugging Face repo, make sure you are authenticated:\n"
|
||||||
|
"https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
|
||||||
|
) from None
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user