mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-15 23:11:12 +08:00
redundancy fix + nits
This commit is contained in:
parent
65a49dda0e
commit
baeb9f117f
@ -226,6 +226,39 @@ def build_parser():
|
||||
)
|
||||
return parser
|
||||
|
||||
def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
|
||||
training_args = GRPOTrainingArgs(
|
||||
batch_size=args.batch_size,
|
||||
iters=args.iters,
|
||||
val_batches=args.val_batches,
|
||||
steps_per_report=args.steps_per_report,
|
||||
steps_per_eval=args.steps_per_eval,
|
||||
steps_per_save=args.save_every,
|
||||
adapter_file=adapter_file,
|
||||
max_seq_length=args.max_seq_length,
|
||||
max_completion_length=args.max_completion_length,
|
||||
grad_checkpoint=args.grad_checkpoint,
|
||||
beta=args.beta,
|
||||
group_size=args.group_size,
|
||||
epsilon=args.epsilon,
|
||||
reference_model_path=args.reference_model_path
|
||||
)
|
||||
|
||||
if args.reference_model_path:
|
||||
reference_model, _ = load(args.reference_model_path)
|
||||
else:
|
||||
reference_model, _ = load(args.model)
|
||||
|
||||
train_grpo(
|
||||
model=model,
|
||||
ref_model=reference_model.freeze(),
|
||||
tokenizer=tokenizer,
|
||||
optimizer=opt,
|
||||
train_dataset=train_set,
|
||||
val_dataset=valid_set,
|
||||
args=training_args,
|
||||
training_callback=training_callback,
|
||||
)
|
||||
|
||||
def train_model(
|
||||
args,
|
||||
@ -263,19 +296,6 @@ def train_model(
|
||||
adapter_file = adapter_path / "adapters.safetensors"
|
||||
save_config(vars(args), adapter_path / "adapter_config.json")
|
||||
|
||||
# init training args
|
||||
training_args = TrainingArgs(
|
||||
batch_size=args.batch_size,
|
||||
iters=args.iters,
|
||||
val_batches=args.val_batches,
|
||||
steps_per_report=args.steps_per_report,
|
||||
steps_per_eval=args.steps_per_eval,
|
||||
steps_per_save=args.save_every,
|
||||
adapter_file=adapter_file,
|
||||
max_seq_length=args.max_seq_length,
|
||||
grad_checkpoint=args.grad_checkpoint,
|
||||
)
|
||||
|
||||
model.train()
|
||||
opt = optim.Adam(
|
||||
learning_rate=(
|
||||
@ -285,37 +305,15 @@ def train_model(
|
||||
|
||||
# Train model
|
||||
if args.training_mode == "grpo":
|
||||
training_args = GRPOTrainingArgs(
|
||||
batch_size=args.batch_size,
|
||||
iters=args.iters,
|
||||
val_batches=args.val_batches,
|
||||
steps_per_report=args.steps_per_report,
|
||||
steps_per_eval=args.steps_per_eval,
|
||||
steps_per_save=args.save_every,
|
||||
adapter_file=adapter_file,
|
||||
max_seq_length=args.max_seq_length,
|
||||
max_completion_length=args.max_completion_length,
|
||||
grad_checkpoint=args.grad_checkpoint,
|
||||
beta=args.beta,
|
||||
group_size=args.group_size,
|
||||
epsilon=args.epsilon,
|
||||
reference_model_path=args.reference_model_path
|
||||
)
|
||||
|
||||
if args.reference_model_path:
|
||||
reference_model, _ = load(args.reference_model_path)
|
||||
else:
|
||||
reference_model, _ = load(args.model)
|
||||
|
||||
train_grpo(
|
||||
model=model,
|
||||
ref_model=reference_model.freeze(),
|
||||
tokenizer=tokenizer,
|
||||
optimizer=opt,
|
||||
train_dataset=train_set,
|
||||
val_dataset=valid_set,
|
||||
args=training_args,
|
||||
training_callback=training_callback,
|
||||
train_model_grpo(
|
||||
model,
|
||||
tokenizer,
|
||||
args,
|
||||
opt,
|
||||
train_set,
|
||||
valid_set,
|
||||
adapter_file,
|
||||
training_callback
|
||||
)
|
||||
else:
|
||||
training_args = TrainingArgs(
|
||||
|
@ -29,9 +29,7 @@ class GRPODataset:
|
||||
if use_chat_template:
|
||||
prompt_tokens = tokenizer.apply_chat_template(
|
||||
[
|
||||
{'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
|
||||
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
|
||||
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
|
||||
{'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
|
||||
{'role': 'user', 'content': prompt_str}
|
||||
],
|
||||
add_generation_prompt=True
|
||||
@ -39,10 +37,7 @@ class GRPODataset:
|
||||
answer_tokens = tokenizer.encode(answer_str)
|
||||
else:
|
||||
if use_prompt:
|
||||
prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
|
||||
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
|
||||
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
|
||||
User: {prompt_str} Assistant: """)
|
||||
prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. User: {prompt_str} Assistant: """)
|
||||
else:
|
||||
prompt_tokens = tokenizer.encode(prompt_str)
|
||||
answer_tokens = tokenizer.encode(answer_str)
|
||||
|
Loading…
Reference in New Issue
Block a user