first successful training run

This commit is contained in:
Goekdeniz-Guelmez
2025-02-04 09:18:45 +01:00
parent ca32424043
commit 7173840283
3 changed files with 68 additions and 66 deletions

View File

@@ -63,11 +63,16 @@ CONFIG_DEFAULTS = {
"config": None,
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
# GRPO args
"reference_model_path": None,
"group_size": 4,
"beta": 0.1,
"epsilon": 1e-4,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"max_completion_length": 512,
"use_chat_template": False,
"use_prompt": False,
}
@@ -178,9 +183,15 @@ def build_parser():
parser.add_argument(
"--group-size",
type=int,
help="Number of responses per prompt.",
help="Number of generations.",
default=4,
)
parser.add_argument(
"--max-completion-length",
type=int,
help="Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.",
default=512,
)
parser.add_argument(
"--beta",
type=float,
@@ -193,6 +204,18 @@ def build_parser():
help="The Epsilon for numerical stability.",
default=1e-4,
)
parser.add_argument(
"--use-chat-template",
type=bool,
help="If the model is a Chat model, use the Chat template.",
default=False,
)
parser.add_argument(
"--use-prompt",
type=bool,
help="Rather to use the prompt from teh R1 paper.",
default=False,
)
return parser
@@ -262,6 +285,7 @@ def train_model(
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
max_completion_length=args.max_completion_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
group_size=args.group_size,
@@ -273,7 +297,7 @@ def train_model(
reference_model, _ = load(args.reference_model_path)
reference_model = reference_model.freeze()
else:
reference_model, _ = None, None
reference_model, _ = load(args.model)
train_grpo(
model=model,