fix(lora): config yaml & arg default merge bug (#1196)

This commit is contained in:
Jarrett 2025-01-09 12:33:54 -07:00 committed by GitHub
parent b8f0cacfa8
commit 40b88eff48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -58,6 +58,8 @@ CONFIG_DEFAULTS = {
     "test": False,
     "test_batches": 500,
     "max_seq_length": 2048,
+    "config": None,
+    "grad_checkpoint": False,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
 }
@@ -67,6 +69,7 @@ def build_parser():
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
     parser.add_argument(
         "--model",
+        type=str,
         help="The path to the local model directory or Hugging Face repo.",
     )
@@ -75,7 +78,6 @@ def build_parser():
         "--train",
         action="store_true",
         help="Do training",
-        default=None,
     )
     parser.add_argument(
         "--data",
@@ -89,7 +91,6 @@ def build_parser():
         "--fine-tune-type",
         type=str,
         choices=["lora", "dora", "full"],
-        default="lora",
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
     parser.add_argument(
@@ -134,7 +135,6 @@ def build_parser():
         "--test",
         action="store_true",
         help="Evaluate on the test set after training",
-        default=None,
     )
     parser.add_argument(
         "--test-batches",
@@ -149,16 +149,15 @@ def build_parser():
     parser.add_argument(
         "-c",
         "--config",
-        default=None,
+        type=str,
         help="A YAML configuration file with the training options",
     )
     parser.add_argument(
         "--grad-checkpoint",
         action="store_true",
         help="Use gradient checkpointing to reduce memory use.",
-        default=None,
     )
-    parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
+    parser.add_argument("--seed", type=int, help="The PRNG seed")
     return parser