diff --git a/lora/lora.py b/lora/lora.py
index 723e783d..8e96f3b0 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -22,6 +22,13 @@ def build_parser():
         default="mlx_model",
         help="The path to the local model directory or Hugging Face repo.",
     )
+
+    parser.add_argument(
+        "--revision",
+        default="main",
+        help="Specify the version of the model to use. This can be a branch name, tag, or commit hash. Defaults to 'main'.",
+    )
+
     # Generation args
     parser.add_argument(
         "--max-tokens",
@@ -333,7 +340,7 @@ if __name__ == "__main__":
         tokenizer_config["add_eos_token"] = bool(args.add_eos_token)
 
     print("Loading pretrained model")
-    model, tokenizer, _ = lora_utils.load(args.model, tokenizer_config)
+    model, tokenizer, _ = lora_utils.load(args.model, args.revision, tokenizer_config)
     # Freeze all layers other than LORA linears
     model.freeze()
     for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
diff --git a/lora/utils.py b/lora/utils.py
index a334723c..b278179f 100644
--- a/lora/utils.py
+++ b/lora/utils.py
@@ -13,9 +13,10 @@ import transformers
 from huggingface_hub import snapshot_download
 
 
-def fetch_from_hub(hf_path: str):
+def fetch_from_hub(hf_path: str, revision: str = "main"):
     model_path = snapshot_download(
         repo_id=hf_path,
+        revision=revision,
         allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
     )
     weight_files = glob.glob(f"{model_path}/*.safetensors")
@@ -122,7 +123,7 @@ def save_model(save_dir: str, weights, tokenizer, config):
     )
 
 
-def load(path_or_hf_repo: str, tokenizer_config={}):
+def load(path_or_hf_repo: str, revision: str = "main", tokenizer_config={}):
     # If the path exists, it will try to load model form it
     # otherwise download and cache from the hf_repo and cache
     model_path = Path(path_or_hf_repo)
@@ -130,6 +131,7 @@ def load(path_or_hf_repo: str, tokenizer_config={}):
         model_path = Path(
             snapshot_download(
                 repo_id=path_or_hf_repo,
+                revision=revision,
                 allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
             )
         )
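
A minimal usage sketch of the change; the repo id, revision value, and `--train` flag below are illustrative placeholders, not part of this patch:

    # CLI: pin the downloaded model files to a specific branch, tag, or
    # commit instead of the default "main".
    python lora.py --model <hf-repo-or-local-path> --revision <branch-tag-or-commit> --train

    # Programmatic: the new positional order is
    # (path_or_hf_repo, revision, tokenizer_config).
    model, tokenizer, config = lora_utils.load(
        "some-org/some-model",  # illustrative Hugging Face repo id
        revision="v1.0",        # branch name, tag, or commit hash; defaults to "main"
        tokenizer_config={},
    )

Note that `revision` is inserted before `tokenizer_config` in `load`'s signature, so any external callers passing `tokenizer_config` positionally would need updating, as the call in lora.py is here.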