mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
fix(lora): config yaml & arg default merge bug (#1196)
This commit is contained in:
parent
b8f0cacfa8
commit
40b88eff48
@ -58,6 +58,8 @@ CONFIG_DEFAULTS = {
|
|||||||
"test": False,
|
"test": False,
|
||||||
"test_batches": 500,
|
"test_batches": 500,
|
||||||
"max_seq_length": 2048,
|
"max_seq_length": 2048,
|
||||||
|
"config": None,
|
||||||
|
"grad_checkpoint": False,
|
||||||
"lr_schedule": None,
|
"lr_schedule": None,
|
||||||
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
||||||
}
|
}
|
||||||
@ -67,6 +69,7 @@ def build_parser():
|
|||||||
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
|
type=str,
|
||||||
help="The path to the local model directory or Hugging Face repo.",
|
help="The path to the local model directory or Hugging Face repo.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -75,7 +78,6 @@ def build_parser():
|
|||||||
"--train",
|
"--train",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Do training",
|
help="Do training",
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data",
|
"--data",
|
||||||
@ -89,7 +91,6 @@ def build_parser():
|
|||||||
"--fine-tune-type",
|
"--fine-tune-type",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["lora", "dora", "full"],
|
choices=["lora", "dora", "full"],
|
||||||
default="lora",
|
|
||||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -134,7 +135,6 @@ def build_parser():
|
|||||||
"--test",
|
"--test",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Evaluate on the test set after training",
|
help="Evaluate on the test set after training",
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--test-batches",
|
"--test-batches",
|
||||||
@ -149,16 +149,15 @@ def build_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-c",
|
"-c",
|
||||||
"--config",
|
"--config",
|
||||||
default=None,
|
type=str,
|
||||||
help="A YAML configuration file with the training options",
|
help="A YAML configuration file with the training options",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--grad-checkpoint",
|
"--grad-checkpoint",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Use gradient checkpointing to reduce memory use.",
|
help="Use gradient checkpointing to reduce memory use.",
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user