This commit is contained in:
Sindhu Satish 2025-01-29 06:01:07 -08:00 committed by GitHub
commit 651f9a5cf8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 3 deletions

View File

@ -73,6 +73,13 @@ def build_parser():
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--revision",
default="main",
type=str,
help="Hash value of the commit to checkout from the Hugging Face repo.",
)
# Training args
parser.add_argument(
"--train",
@ -252,7 +259,7 @@ def run(args, training_callback: TrainingCallback = None):
np.random.seed(args.seed)
print("Loading pretrained model")
model, tokenizer = load(args.model)
model, tokenizer = load(args.model, args.revision)
print("Loading datasets")
train_set, valid_set, test_set = load_dataset(args, tokenizer)

View File

@ -152,7 +152,7 @@ def compute_bits_per_weight(model):
return model_bytes * 8 / model_params
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = "main") -> Path:
"""
Ensures the model is available locally. If the path does not exist locally,
it is downloaded from the Hugging Face Hub.
@ -184,7 +184,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
)
except:
raise ModelNotFoundError(
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
f"Model not found for path or HF repo: {path_or_hf_repo}:{revision}.\n"
"Please make sure you specified the local path or Hugging Face"
" repo id correctly.\nIf you are trying to access a private or"
" gated Hugging Face repo, make sure you are authenticated:\n"
@ -709,6 +709,7 @@ def load(
model_config={},
adapter_path: Optional[str] = None,
lazy: bool = False,
revision: Optional[str] = "main",
) -> Tuple[nn.Module, TokenizerWrapper]:
"""
Load the model and tokenizer from a given path or a huggingface repository.