fix: Added dedicated error handling to load and get_model_path (#775)

* fix: Added dedicated error handling to load and get_model_path

Added proper error handling to load and get_model_path by adding a dedicated exception class, because when the local path is not right, it still throws the huggingface RepositoryNotFoundError

* fix: Changed error message and resolved lack of import

* fix: Removed redundant try-catch block

* nits in message

* nits in message

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
AtakanTekparmak 2024-05-20 15:39:05 +02:00 committed by GitHub
parent e92de216fd
commit 199df9e110
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,6 +14,7 @@ from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from mlx.utils import tree_flatten from mlx.utils import tree_flatten
from transformers import AutoTokenizer, PreTrainedTokenizer from transformers import AutoTokenizer, PreTrainedTokenizer
@ -33,6 +34,12 @@ MODEL_REMAPPING = {
MAX_FILE_SIZE_GB = 5 MAX_FILE_SIZE_GB = 5
class ModelNotFoundError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
def _get_classes(config: dict): def _get_classes(config: dict):
""" """
Retrieve the model and model args classes based on the configuration. Retrieve the model and model args classes based on the configuration.
@ -69,20 +76,29 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
""" """
model_path = Path(path_or_hf_repo) model_path = Path(path_or_hf_repo)
if not model_path.exists(): if not model_path.exists():
model_path = Path( try:
snapshot_download( model_path = Path(
repo_id=path_or_hf_repo, snapshot_download(
revision=revision, repo_id=path_or_hf_repo,
allow_patterns=[ revision=revision,
"*.json", allow_patterns=[
"*.safetensors", "*.json",
"*.py", "*.safetensors",
"tokenizer.model", "*.py",
"*.tiktoken", "tokenizer.model",
"*.txt", "*.tiktoken",
], "*.txt",
],
)
) )
) except RepositoryNotFoundError:
raise ModelNotFoundError(
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
"Please make sure you specified the local path or Hugging Face"
" repo id correctly.\nIf you are trying to access a private or"
" gated Hugging Face repo, make sure you are authenticated:\n"
"https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
) from None
return model_path return model_path