Utilize a specific model version from HuggingFace during LoRA fine-tuning

This commit is contained in:
Sindhu Satish 2025-01-09 21:01:30 -08:00
parent 5cae0a60e6
commit 28eebfe5bf
2 changed files with 12 additions and 3 deletions

View File

@ -22,6 +22,13 @@ def build_parser():
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--revision",
default="main",
help="Specify the version of the model to use. This can be a branch name, tag, or commit hash. Defaults to 'main'.",
)
# Generation args
parser.add_argument(
"--max-tokens",
@ -333,7 +340,7 @@ if __name__ == "__main__":
tokenizer_config["add_eos_token"] = bool(args.add_eos_token)
print("Loading pretrained model")
model, tokenizer, _ = lora_utils.load(args.model, tokenizer_config)
model, tokenizer, _ = lora_utils.load(args.model, args.revision, tokenizer_config)
# Freeze all layers other than LORA linears
model.freeze()
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:

View File

@ -13,9 +13,10 @@ import transformers
from huggingface_hub import snapshot_download
def fetch_from_hub(hf_path: str):
def fetch_from_hub(hf_path: str, revision: str = "main"):
model_path = snapshot_download(
repo_id=hf_path,
revision=revision,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
weight_files = glob.glob(f"{model_path}/*.safetensors")
@ -122,7 +123,7 @@ def save_model(save_dir: str, weights, tokenizer, config):
)
def load(path_or_hf_repo: str, tokenizer_config={}):
def load(path_or_hf_repo: str, revision: str = "main", tokenizer_config={}):
# If the path exists, it will try to load the model from it;
# otherwise, download from the hf_repo and cache it
model_path = Path(path_or_hf_repo)
@ -130,6 +131,7 @@ def load(path_or_hf_repo: str, tokenizer_config={}):
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
)