redundancy fix + nits

Goekdeniz-Guelmez 2025-02-14 09:09:59 +01:00
parent 65a49dda0e
commit baeb9f117f
2 changed files with 44 additions and 51 deletions

View File

@@ -226,6 +226,39 @@ def build_parser():
     )
     return parser

+def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
+    training_args = GRPOTrainingArgs(
+        batch_size=args.batch_size,
+        iters=args.iters,
+        val_batches=args.val_batches,
+        steps_per_report=args.steps_per_report,
+        steps_per_eval=args.steps_per_eval,
+        steps_per_save=args.save_every,
+        adapter_file=adapter_file,
+        max_seq_length=args.max_seq_length,
+        max_completion_length=args.max_completion_length,
+        grad_checkpoint=args.grad_checkpoint,
+        beta=args.beta,
+        group_size=args.group_size,
+        epsilon=args.epsilon,
+        reference_model_path=args.reference_model_path
+    )
+
+    if args.reference_model_path:
+        reference_model, _ = load(args.reference_model_path)
+    else:
+        reference_model, _ = load(args.model)
+
+    train_grpo(
+        model=model,
+        ref_model=reference_model.freeze(),
+        tokenizer=tokenizer,
+        optimizer=opt,
+        train_dataset=train_set,
+        val_dataset=valid_set,
+        args=training_args,
+        training_callback=training_callback,
+    )

 def train_model(
     args,
@@ -263,19 +296,6 @@ def train_model(
     adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")

-    # init training args
-    training_args = TrainingArgs(
-        batch_size=args.batch_size,
-        iters=args.iters,
-        val_batches=args.val_batches,
-        steps_per_report=args.steps_per_report,
-        steps_per_eval=args.steps_per_eval,
-        steps_per_save=args.save_every,
-        adapter_file=adapter_file,
-        max_seq_length=args.max_seq_length,
-        grad_checkpoint=args.grad_checkpoint,
-    )
-
     model.train()
     opt = optim.Adam(
         learning_rate=(
@@ -285,37 +305,15 @@ def train_model(
     # Train model
     if args.training_mode == "grpo":
-        training_args = GRPOTrainingArgs(
-            batch_size=args.batch_size,
-            iters=args.iters,
-            val_batches=args.val_batches,
-            steps_per_report=args.steps_per_report,
-            steps_per_eval=args.steps_per_eval,
-            steps_per_save=args.save_every,
-            adapter_file=adapter_file,
-            max_seq_length=args.max_seq_length,
-            max_completion_length=args.max_completion_length,
-            grad_checkpoint=args.grad_checkpoint,
-            beta=args.beta,
-            group_size=args.group_size,
-            epsilon=args.epsilon,
-            reference_model_path=args.reference_model_path
-        )
-        if args.reference_model_path:
-            reference_model, _ = load(args.reference_model_path)
-        else:
-            reference_model, _ = load(args.model)
-        train_grpo(
-            model=model,
-            ref_model=reference_model.freeze(),
-            tokenizer=tokenizer,
-            optimizer=opt,
-            train_dataset=train_set,
-            val_dataset=valid_set,
-            args=training_args,
-            training_callback=training_callback,
-        )
+        train_model_grpo(
+            model,
+            tokenizer,
+            args,
+            opt,
+            train_set,
+            valid_set,
+            adapter_file,
+            training_callback
+        )
     else:
         training_args = TrainingArgs(

View File

@@ -29,9 +29,7 @@ class GRPODataset:
             if use_chat_template:
                 prompt_tokens = tokenizer.apply_chat_template(
                     [
-                        {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
-                        The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
-                        The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
+                        {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
                         {'role': 'user', 'content': prompt_str}
                     ],
                     add_generation_prompt=True
@@ -39,10 +37,7 @@ class GRPODataset:
                 answer_tokens = tokenizer.encode(answer_str)
             else:
                 if use_prompt:
-                    prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
-                    The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
-                    The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
-                    User: {prompt_str} Assistant: """)
+                    prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. User: {prompt_str} Assistant: """)
                 else:
                     prompt_tokens = tokenizer.encode(prompt_str)
                 answer_tokens = tokenizer.encode(answer_str)
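For reference, a minimal, self-contained sketch of the R1-style prompt the GRPODataset hunks above construct, assuming a Hugging Face-style tokenizer with a chat template; the model name and user question below are placeholders, not part of this commit.

# Illustrative sketch only: reproduces the system prompt used in GRPODataset
# with a stock Hugging Face tokenizer. Model name and question are placeholders.
from transformers import AutoTokenizer

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, "
    "and the Assistant solves it. The assistant first thinks about the "
    "reasoning process in the mind and then provides the user with the "
    "answer. The reasoning process and answer are enclosed within "
    "<think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>."
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# apply_chat_template with add_generation_prompt=True returns the token ids
# of the formatted conversation, ready to prepend to a sampled completion.
prompt_tokens = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "What is 7 * 6?"},
    ],
    add_generation_prompt=True,
)
print(tokenizer.decode(prompt_tokens))

Decoding the ids shows the system and user turns wrapped in the model's chat template with the generation prompt appended, which is the shape the use_chat_template branch produces.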