redundancy fix + nits

Goekdeniz-Guelmez
2025-02-14 09:09:59 +01:00
parent 65a49dda0e
commit baeb9f117f
2 changed files with 44 additions and 51 deletions


@@ -226,6 +226,39 @@ def build_parser():
     )
     return parser
 
+
+def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
+    training_args = GRPOTrainingArgs(
+        batch_size=args.batch_size,
+        iters=args.iters,
+        val_batches=args.val_batches,
+        steps_per_report=args.steps_per_report,
+        steps_per_eval=args.steps_per_eval,
+        steps_per_save=args.save_every,
+        adapter_file=adapter_file,
+        max_seq_length=args.max_seq_length,
+        max_completion_length=args.max_completion_length,
+        grad_checkpoint=args.grad_checkpoint,
+        beta=args.beta,
+        group_size=args.group_size,
+        epsilon=args.epsilon,
+        reference_model_path=args.reference_model_path
+    )
+
+    if args.reference_model_path:
+        reference_model, _ = load(args.reference_model_path)
+    else:
+        reference_model, _ = load(args.model)
+
+    train_grpo(
+        model=model,
+        ref_model=reference_model.freeze(),
+        tokenizer=tokenizer,
+        optimizer=opt,
+        train_dataset=train_set,
+        val_dataset=valid_set,
+        args=training_args,
+        training_callback=training_callback,
+    )
 
 def train_model(
     args,
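For readers skimming the diff: the new `train_model_grpo` helper builds a `GRPOTrainingArgs` from the parsed CLI args, loads a frozen reference model (falling back to the policy model itself when no reference path is given), and hands everything to `train_grpo`. As a rough sketch, the config object plausibly looks like the following; the fields are inferred purely from the keyword arguments at this call site, while the defaults and the flat dataclass layout are assumptions, not the actual definition in the trainer module:

```python
# Sketch of the GRPO config, inferred from the call site above.
# Defaults (and not inheriting from TrainingArgs) are assumptions.
from dataclasses import dataclass
from typing import Optional


@dataclass
class GRPOTrainingArgs:
    # Fields shared with the standard LoRA TrainingArgs:
    batch_size: int = 4
    iters: int = 1000
    val_batches: int = 25
    steps_per_report: int = 10
    steps_per_eval: int = 200
    steps_per_save: int = 100
    adapter_file: str = "adapters.safetensors"
    max_seq_length: int = 2048
    grad_checkpoint: bool = False
    # GRPO-specific knobs:
    max_completion_length: int = 512  # cap on sampled completion length
    group_size: int = 4               # completions drawn per prompt
    beta: float = 0.1                 # KL penalty against the reference model
    epsilon: float = 1e-4             # small constant for numerical stability
    reference_model_path: Optional[str] = None  # None -> reuse args.model, frozen
```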
@@ -263,19 +296,6 @@ def train_model(
     adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")
 
-    # init training args
-    training_args = TrainingArgs(
-        batch_size=args.batch_size,
-        iters=args.iters,
-        val_batches=args.val_batches,
-        steps_per_report=args.steps_per_report,
-        steps_per_eval=args.steps_per_eval,
-        steps_per_save=args.save_every,
-        adapter_file=adapter_file,
-        max_seq_length=args.max_seq_length,
-        grad_checkpoint=args.grad_checkpoint,
-    )
-
     model.train()
     opt = optim.Adam(
         learning_rate=(
@@ -285,37 +305,15 @@ def train_model(
     # Train model
     if args.training_mode == "grpo":
-        training_args = GRPOTrainingArgs(
-            batch_size=args.batch_size,
-            iters=args.iters,
-            val_batches=args.val_batches,
-            steps_per_report=args.steps_per_report,
-            steps_per_eval=args.steps_per_eval,
-            steps_per_save=args.save_every,
-            adapter_file=adapter_file,
-            max_seq_length=args.max_seq_length,
-            max_completion_length=args.max_completion_length,
-            grad_checkpoint=args.grad_checkpoint,
-            beta=args.beta,
-            group_size=args.group_size,
-            epsilon=args.epsilon,
-            reference_model_path=args.reference_model_path
-        )
-
-        if args.reference_model_path:
-            reference_model, _ = load(args.reference_model_path)
-        else:
-            reference_model, _ = load(args.model)
-
-        train_grpo(
-            model=model,
-            ref_model=reference_model.freeze(),
-            tokenizer=tokenizer,
-            optimizer=opt,
-            train_dataset=train_set,
-            val_dataset=valid_set,
-            args=training_args,
-            training_callback=training_callback,
+        train_model_grpo(
+            model,
+            tokenizer,
+            args,
+            opt,
+            train_set,
+            valid_set,
+            adapter_file,
+            training_callback
         )
     else:
+        training_args = TrainingArgs(
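Net effect of the commit: the GRPO setup, previously duplicated inline, collapses into the single helper, and `train_model` only dispatches on `training_mode`. For anyone driving the helper programmatically rather than through the CLI, something along these lines should work; this is a hypothetical sketch in which every path, hyperparameter value, and the dataset placeholders are illustrative only, with the namespace fields mirroring exactly what `train_model_grpo` reads:

```python
# Hypothetical direct invocation of the extracted helper; assumes
# train_model_grpo is importable from the LoRA script in this repo.
from types import SimpleNamespace

import mlx.optimizers as optim
from mlx_lm import load

args = SimpleNamespace(
    model="mlx-community/some-model",  # placeholder model path
    batch_size=1,
    iters=200,
    val_batches=10,
    steps_per_report=10,
    steps_per_eval=50,
    save_every=100,
    max_seq_length=1024,
    max_completion_length=256,
    grad_checkpoint=False,
    beta=0.1,
    group_size=4,
    epsilon=1e-4,
    reference_model_path=None,  # None -> helper reloads args.model as the frozen reference
)

model, tokenizer = load(args.model)
opt = optim.Adam(learning_rate=1e-5)

# Placeholders: in practice these come from the repo's dataset loading step.
train_set = valid_set = None

train_model_grpo(
    model, tokenizer, args, opt,
    train_set, valid_set,
    adapter_file="adapters/adapters.safetensors",
    training_callback=None,
)
```

Passing the shared `opt` into the helper (rather than constructing one inside it) is what lets the GRPO and plain-LoRA branches keep a single optimizer setup path in `train_model`.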