diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 43f508c3..7ebfb100 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -73,6 +73,13 @@ def build_parser(): help="The path to the local model directory or Hugging Face repo.", ) + parser.add_argument( + "--revision", + default="main", + type=str, + help="The revision (branch name, tag, or commit hash) to use from the Hugging Face repo.", + ) + # Training args parser.add_argument( "--train", @@ -252,7 +259,7 @@ def run(args, training_callback: TrainingCallback = None): np.random.seed(args.seed) print("Loading pretrained model") - model, tokenizer = load(args.model) + model, tokenizer = load(args.model, revision=args.revision) print("Loading datasets") train_set, valid_set, test_set = load_dataset(args, tokenizer) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 0150f1b7..8e48ab25 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -152,7 +152,7 @@ def compute_bits_per_weight(model): return model_bytes * 8 / model_params -def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path: +def get_model_path(path_or_hf_repo: str, revision: Optional[str] = "main") -> Path: """ Ensures the model is available locally. If the path does not exist locally, it is downloaded from the Hugging Face Hub. 
@@ -184,7 +184,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path ) except: raise ModelNotFoundError( - f"Model not found for path or HF repo: {path_or_hf_repo}.\n" + f"Model not found for path or HF repo: {path_or_hf_repo}:{revision}.\n" "Please make sure you specified the local path or Hugging Face" " repo id correctly.\nIf you are trying to access a private or" " gated Hugging Face repo, make sure you are authenticated:\n" @@ -709,6 +709,7 @@ def load( model_config={}, adapter_path: Optional[str] = None, lazy: bool = False, + revision: Optional[str] = "main", ) -> Tuple[nn.Module, TokenizerWrapper]: """ Load the model and tokenizer from a given path or a huggingface repository.