Utilize a specific model version from HuggingFace during LoRA fine-tuning

This commit is contained in:
Sindhu Satish 2025-01-09 21:01:30 -08:00
parent 5cae0a60e6
commit 28eebfe5bf
2 changed files with 12 additions and 3 deletions

View File

@ -22,6 +22,13 @@ def build_parser():
default="mlx_model", default="mlx_model",
help="The path to the local model directory or Hugging Face repo.", help="The path to the local model directory or Hugging Face repo.",
) )
parser.add_argument(
"--revision",
default="main",
help="Specify the version of the model to use. This can be a branch name, tag, or commit hash. Defaults to 'main'.",
)
# Generation args # Generation args
parser.add_argument( parser.add_argument(
"--max-tokens", "--max-tokens",
@ -333,7 +340,7 @@ if __name__ == "__main__":
tokenizer_config["add_eos_token"] = bool(args.add_eos_token) tokenizer_config["add_eos_token"] = bool(args.add_eos_token)
print("Loading pretrained model") print("Loading pretrained model")
model, tokenizer, _ = lora_utils.load(args.model, tokenizer_config) model, tokenizer, _ = lora_utils.load(args.model, args.revision, tokenizer_config)
# Freeze all layers other than LORA linears # Freeze all layers other than LORA linears
model.freeze() model.freeze()
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]: for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:

View File

@ -13,9 +13,10 @@ import transformers
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
def fetch_from_hub(hf_path: str): def fetch_from_hub(hf_path: str, revision: str = "main"):
model_path = snapshot_download( model_path = snapshot_download(
repo_id=hf_path, repo_id=hf_path,
revision=revision,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
) )
weight_files = glob.glob(f"{model_path}/*.safetensors") weight_files = glob.glob(f"{model_path}/*.safetensors")
@ -122,7 +123,7 @@ def save_model(save_dir: str, weights, tokenizer, config):
) )
def load(path_or_hf_repo: str, tokenizer_config={}): def load(path_or_hf_repo: str, revision: str = "main", tokenizer_config={}):
# If the path exists, it will try to load model from it # otherwise download and cache from the hf_repo and cache
# otherwise download and cache from the hf_repo and cache # otherwise download and cache from the hf_repo and cache
model_path = Path(path_or_hf_repo) model_path = Path(path_or_hf_repo)
@ -130,6 +131,7 @@ def load(path_or_hf_repo: str, tokenizer_config={}):
model_path = Path( model_path = Path(
snapshot_download( snapshot_download(
repo_id=path_or_hf_repo, repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
) )
) )