This commit is contained in:
Sindhu Satish 2025-01-29 06:01:07 -08:00 committed by GitHub
commit 651f9a5cf8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 3 deletions

View File

@@ -73,6 +73,13 @@ def build_parser():
help="The path to the local model directory or Hugging Face repo.", help="The path to the local model directory or Hugging Face repo.",
) )
parser.add_argument(
"--revision",
default="main",
type=str,
help="Hash value of the commit to checkout from the Hugging Face repo.",
)
# Training args # Training args
parser.add_argument( parser.add_argument(
"--train", "--train",
@@ -252,7 +259,7 @@ def run(args, training_callback: TrainingCallback = None):
np.random.seed(args.seed) np.random.seed(args.seed)
print("Loading pretrained model") print("Loading pretrained model")
model, tokenizer = load(args.model) model, tokenizer = load(args.model, args.revision)
print("Loading datasets") print("Loading datasets")
train_set, valid_set, test_set = load_dataset(args, tokenizer) train_set, valid_set, test_set = load_dataset(args, tokenizer)

View File

@@ -152,7 +152,7 @@ def compute_bits_per_weight(model):
return model_bytes * 8 / model_params return model_bytes * 8 / model_params
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path: def get_model_path(path_or_hf_repo: str, revision: Optional[str] = "main") -> Path:
""" """
Ensures the model is available locally. If the path does not exist locally, Ensures the model is available locally. If the path does not exist locally,
it is downloaded from the Hugging Face Hub. it is downloaded from the Hugging Face Hub.
@@ -184,7 +184,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
) )
except: except:
raise ModelNotFoundError( raise ModelNotFoundError(
f"Model not found for path or HF repo: {path_or_hf_repo}.\n" f"Model not found for path or HF repo: {path_or_hf_repo}:{revision}.\n"
"Please make sure you specified the local path or Hugging Face" "Please make sure you specified the local path or Hugging Face"
" repo id correctly.\nIf you are trying to access a private or" " repo id correctly.\nIf you are trying to access a private or"
" gated Hugging Face repo, make sure you are authenticated:\n" " gated Hugging Face repo, make sure you are authenticated:\n"
@@ -709,6 +709,7 @@ def load(
model_config={}, model_config={},
adapter_path: Optional[str] = None, adapter_path: Optional[str] = None,
lazy: bool = False, lazy: bool = False,
revision: Optional[str] = "main",
) -> Tuple[nn.Module, TokenizerWrapper]: ) -> Tuple[nn.Module, TokenizerWrapper]:
""" """
Load the model and tokenizer from a given path or a huggingface repository. Load the model and tokenizer from a given path or a huggingface repository.