From 199df9e1105a49a5ac064ff2c6abb7bcd1b7285c Mon Sep 17 00:00:00 2001
From: AtakanTekparmak <59488384+AtakanTekparmak@users.noreply.github.com>
Date: Mon, 20 May 2024 15:39:05 +0200
Subject: [PATCH] fix: Added dedicated error handling to load and get_model_path (#775)

* fix: Added dedicated error handling to load and get_model_path

Added proper error handling to load and get_model_path by adding a
dedicated exception class, because when the local path is not right, it
still throws the huggingface RepositoryNotFoundError

* fix: Changed error message and resolved lack of import

* fix: Removed redundant try-catch block

* nits in message

* nits in message

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/utils.py | 42 +++++++++++++++++++++++++++++-------------
 1 file changed, 29 insertions(+), 13 deletions(-)

diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 305cf518..11653572 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -14,6 +14,7 @@ from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
+from huggingface_hub.utils._errors import RepositoryNotFoundError
 from mlx.utils import tree_flatten
 from transformers import AutoTokenizer, PreTrainedTokenizer
 
@@ -33,6 +34,12 @@ MODEL_REMAPPING = {
 MAX_FILE_SIZE_GB = 5
 
 
+class ModelNotFoundError(Exception):
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
+
+
 def _get_classes(config: dict):
     """
     Retrieve the model and model args classes based on the configuration.
@@ -69,20 +76,29 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
     """
     model_path = Path(path_or_hf_repo)
     if not model_path.exists():
-        model_path = Path(
-            snapshot_download(
-                repo_id=path_or_hf_repo,
-                revision=revision,
-                allow_patterns=[
-                    "*.json",
-                    "*.safetensors",
-                    "*.py",
-                    "tokenizer.model",
-                    "*.tiktoken",
-                    "*.txt",
-                ],
+        try:
+            model_path = Path(
+                snapshot_download(
+                    repo_id=path_or_hf_repo,
+                    revision=revision,
+                    allow_patterns=[
+                        "*.json",
+                        "*.safetensors",
+                        "*.py",
+                        "tokenizer.model",
+                        "*.tiktoken",
+                        "*.txt",
+                    ],
+                )
             )
-        )
+        except RepositoryNotFoundError:
+            raise ModelNotFoundError(
+                f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
+                "Please make sure you specified the local path or Hugging Face"
+                " repo id correctly.\nIf you are trying to access a private or"
+                " gated Hugging Face repo, make sure you are authenticated:\n"
+                "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
+            ) from None
     return model_path
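
Note (not part of the patch): a minimal usage sketch of how a caller might rely on the new
exception after this change. It assumes the touched file is importable as mlx_lm.utils and
uses a made-up repo id; the .message attribute comes from the ModelNotFoundError defined above.

    # Hypothetical caller-side handling of the new exception (sketch, not from the patch)
    from mlx_lm.utils import ModelNotFoundError, get_model_path

    try:
        model_path = get_model_path("not/a-real-repo")  # made-up repo id for illustration
    except ModelNotFoundError as e:
        # Instead of the raw huggingface_hub RepositoryNotFoundError, callers now get a
        # message explaining bad paths/repo ids and pointing at HF authentication docs.
        print(e.message)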