diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 4751123c..4bb39832 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -226,6 +226,39 @@ def build_parser():
)
return parser
+def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
+ training_args = GRPOTrainingArgs(
+ batch_size=args.batch_size,
+ iters=args.iters,
+ val_batches=args.val_batches,
+ steps_per_report=args.steps_per_report,
+ steps_per_eval=args.steps_per_eval,
+ steps_per_save=args.save_every,
+ adapter_file=adapter_file,
+ max_seq_length=args.max_seq_length,
+ max_completion_length=args.max_completion_length,
+ grad_checkpoint=args.grad_checkpoint,
+ beta=args.beta,
+ group_size=args.group_size,
+ epsilon=args.epsilon,
+ reference_model_path=args.reference_model_path
+ )
+
+ if args.reference_model_path:
+ reference_model, _ = load(args.reference_model_path)
+ else:
+ reference_model, _ = load(args.model)
+
+ train_grpo(
+ model=model,
+ ref_model=reference_model.freeze(),
+ tokenizer=tokenizer,
+ optimizer=opt,
+ train_dataset=train_set,
+ val_dataset=valid_set,
+ args=training_args,
+ training_callback=training_callback,
+ )
def train_model(
args,
@@ -263,19 +296,6 @@ def train_model(
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")
- # init training args
- training_args = TrainingArgs(
- batch_size=args.batch_size,
- iters=args.iters,
- val_batches=args.val_batches,
- steps_per_report=args.steps_per_report,
- steps_per_eval=args.steps_per_eval,
- steps_per_save=args.save_every,
- adapter_file=adapter_file,
- max_seq_length=args.max_seq_length,
- grad_checkpoint=args.grad_checkpoint,
- )
-
model.train()
opt = optim.Adam(
learning_rate=(
@@ -285,37 +305,15 @@ def train_model(
# Train model
if args.training_mode == "grpo":
- training_args = GRPOTrainingArgs(
- batch_size=args.batch_size,
- iters=args.iters,
- val_batches=args.val_batches,
- steps_per_report=args.steps_per_report,
- steps_per_eval=args.steps_per_eval,
- steps_per_save=args.save_every,
- adapter_file=adapter_file,
- max_seq_length=args.max_seq_length,
- max_completion_length=args.max_completion_length,
- grad_checkpoint=args.grad_checkpoint,
- beta=args.beta,
- group_size=args.group_size,
- epsilon=args.epsilon,
- reference_model_path=args.reference_model_path
- )
-
- if args.reference_model_path:
- reference_model, _ = load(args.reference_model_path)
- else:
- reference_model, _ = load(args.model)
-
- train_grpo(
- model=model,
- ref_model=reference_model.freeze(),
- tokenizer=tokenizer,
- optimizer=opt,
- train_dataset=train_set,
- val_dataset=valid_set,
- args=training_args,
- training_callback=training_callback,
+ train_model_grpo(
+ model,
+ tokenizer,
+ args,
+ opt,
+ train_set,
+ valid_set,
+ adapter_file,
+ training_callback
)
else:
training_args = TrainingArgs(
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index df5c4588..678c4f41 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -29,9 +29,7 @@ class GRPODataset:
if use_chat_template:
prompt_tokens = tokenizer.apply_chat_template(
[
- {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
- The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
- The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."""},
+ {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."""},
{'role': 'user', 'content': prompt_str}
],
add_generation_prompt=True
@@ -39,10 +37,7 @@ class GRPODataset:
answer_tokens = tokenizer.encode(answer_str)
else:
if use_prompt:
- prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
- The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
- The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
- User: {prompt_str} Assistant: """)
+ prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: {prompt_str} Assistant: """)
else:
prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)