mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
reduncancy fix + nits
This commit is contained in:
@@ -226,6 +226,39 @@ def build_parser():
|
||||
)
|
||||
return parser
|
||||
|
||||
def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
|
||||
training_args = GRPOTrainingArgs(
|
||||
batch_size=args.batch_size,
|
||||
iters=args.iters,
|
||||
val_batches=args.val_batches,
|
||||
steps_per_report=args.steps_per_report,
|
||||
steps_per_eval=args.steps_per_eval,
|
||||
steps_per_save=args.save_every,
|
||||
adapter_file=adapter_file,
|
||||
max_seq_length=args.max_seq_length,
|
||||
max_completion_length=args.max_completion_length,
|
||||
grad_checkpoint=args.grad_checkpoint,
|
||||
beta=args.beta,
|
||||
group_size=args.group_size,
|
||||
epsilon=args.epsilon,
|
||||
reference_model_path=args.reference_model_path
|
||||
)
|
||||
|
||||
if args.reference_model_path:
|
||||
reference_model, _ = load(args.reference_model_path)
|
||||
else:
|
||||
reference_model, _ = load(args.model)
|
||||
|
||||
train_grpo(
|
||||
model=model,
|
||||
ref_model=reference_model.freeze(),
|
||||
tokenizer=tokenizer,
|
||||
optimizer=opt,
|
||||
train_dataset=train_set,
|
||||
val_dataset=valid_set,
|
||||
args=training_args,
|
||||
training_callback=training_callback,
|
||||
)
|
||||
|
||||
def train_model(
|
||||
args,
|
||||
@@ -263,19 +296,6 @@ def train_model(
|
||||
adapter_file = adapter_path / "adapters.safetensors"
|
||||
save_config(vars(args), adapter_path / "adapter_config.json")
|
||||
|
||||
# init training args
|
||||
training_args = TrainingArgs(
|
||||
batch_size=args.batch_size,
|
||||
iters=args.iters,
|
||||
val_batches=args.val_batches,
|
||||
steps_per_report=args.steps_per_report,
|
||||
steps_per_eval=args.steps_per_eval,
|
||||
steps_per_save=args.save_every,
|
||||
adapter_file=adapter_file,
|
||||
max_seq_length=args.max_seq_length,
|
||||
grad_checkpoint=args.grad_checkpoint,
|
||||
)
|
||||
|
||||
model.train()
|
||||
opt = optim.Adam(
|
||||
learning_rate=(
|
||||
@@ -285,37 +305,15 @@ def train_model(
|
||||
|
||||
# Train model
|
||||
if args.training_mode == "grpo":
|
||||
training_args = GRPOTrainingArgs(
|
||||
batch_size=args.batch_size,
|
||||
iters=args.iters,
|
||||
val_batches=args.val_batches,
|
||||
steps_per_report=args.steps_per_report,
|
||||
steps_per_eval=args.steps_per_eval,
|
||||
steps_per_save=args.save_every,
|
||||
adapter_file=adapter_file,
|
||||
max_seq_length=args.max_seq_length,
|
||||
max_completion_length=args.max_completion_length,
|
||||
grad_checkpoint=args.grad_checkpoint,
|
||||
beta=args.beta,
|
||||
group_size=args.group_size,
|
||||
epsilon=args.epsilon,
|
||||
reference_model_path=args.reference_model_path
|
||||
)
|
||||
|
||||
if args.reference_model_path:
|
||||
reference_model, _ = load(args.reference_model_path)
|
||||
else:
|
||||
reference_model, _ = load(args.model)
|
||||
|
||||
train_grpo(
|
||||
model=model,
|
||||
ref_model=reference_model.freeze(),
|
||||
tokenizer=tokenizer,
|
||||
optimizer=opt,
|
||||
train_dataset=train_set,
|
||||
val_dataset=valid_set,
|
||||
args=training_args,
|
||||
training_callback=training_callback,
|
||||
train_model_grpo(
|
||||
model,
|
||||
tokenizer,
|
||||
args,
|
||||
opt,
|
||||
train_set,
|
||||
valid_set,
|
||||
adapter_file,
|
||||
training_callback
|
||||
)
|
||||
else:
|
||||
training_args = TrainingArgs(
|
||||
|
||||
Reference in New Issue
Block a user