mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
fix: Added dedicated error handling to load and get_model_path (#775)
* fix: Added dedicated error handling to load and get_model_path Added proper error handling to load and get_model_path by adding a dedicated exception class, because when the local path is not right, it still throws the huggingface RepositoryNotFoundError * fix: Changed error message and resolved lack of import * fix: Removed redundant try-catch block * nits in message * nits in message --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
e92de216fd
commit
199df9e110
@ -14,6 +14,7 @@ from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from mlx.utils import tree_flatten
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
@ -33,6 +34,12 @@ MODEL_REMAPPING = {
|
||||
MAX_FILE_SIZE_GB = 5
|
||||
|
||||
|
||||
class ModelNotFoundError(Exception):
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
def _get_classes(config: dict):
|
||||
"""
|
||||
Retrieve the model and model args classes based on the configuration.
|
||||
@ -69,20 +76,29 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
||||
"""
|
||||
model_path = Path(path_or_hf_repo)
|
||||
if not model_path.exists():
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
revision=revision,
|
||||
allow_patterns=[
|
||||
"*.json",
|
||||
"*.safetensors",
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
],
|
||||
try:
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
revision=revision,
|
||||
allow_patterns=[
|
||||
"*.json",
|
||||
"*.safetensors",
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
raise ModelNotFoundError(
|
||||
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
|
||||
"Please make sure you specified the local path or Hugging Face"
|
||||
" repo id correctly.\nIf you are trying to access a private or"
|
||||
" gated Hugging Face repo, make sure you are authenticated:\n"
|
||||
"https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
|
||||
) from None
|
||||
return model_path
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user